-
Notifications
You must be signed in to change notification settings - Fork 6
/
train.py
216 lines (191 loc) · 10.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#-------------------------------------#
# 对数据集进行训练
#-------------------------------------#
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from loguru import logger
from nets.yolo import YOLOX
from nets.yolo_training import YOLOLoss, weights_init
from utils.callbacks import LossHistory
from utils.dataloader import YoloDataset, yolo_dataset_collate
from utils.utils import get_classes
from utils.utils_fit import fit_one_epoch
from utils.ema import ModelEMA
if __name__ == "__main__":
    #------------------------------------------------------------------#
    #   Script entry point: trains a YOLOX detector in two consecutive
    #   phases (Init_Epoch -> Freeze_Epoch, then Freeze_Epoch ->
    #   UnFreeze_Epoch). All settings are the plain variables below.
    #------------------------------------------------------------------#
    # Fine-tune a pruned model: when True, model_path is loaded as a whole
    # pickled model object instead of a state dict.
    pruned_train = False
    # Whether to enable EMA (exponential moving average) training; uses a
    # little more GPU memory.
    EMA = False
    # Whether to train on the GPU.
    Cuda = True
    #--------------------------------------------------------#
    #   Be sure to change classes_path before training so
    #   that it matches your own dataset.
    #--------------------------------------------------------#
    classes_path = 'model_data/coco_classes.txt'
    #-------------------------------------------------------------------------------------#
    #   See the README for the weight files (Baidu Netdisk download).
    #   Pretrained weights should be used in 99% of cases: without them the weights
    #   are too random, feature extraction is ineffective and the training result
    #   will be poor. Pretrained weights are reusable across datasets because the
    #   learned features are generic.
    #-------------------------------------------------------------------------------------#
    model_path = 'model_data/yolox_s.pth'
    #---------------------------------------------------------------------#
    #   YOLOX variant to use: s, m, l, x
    #---------------------------------------------------------------------#
    phi = 's'
    #------------------------------------------------------#
    #   Input shape; must be a multiple of 32.
    #------------------------------------------------------#
    input_shape = [640, 640]
    #------------------------------------------------------------------------------------------------------------#
    #   YOLOv4-style tricks
    #   mosaic              Mosaic data augmentation, True or False.
    #                       The YOLOX authors stress that Mosaic should be switched off for the last N
    #                       epochs, because Mosaic images are far from the distribution of natural
    #                       images and its heavy cropping produces many inaccurate boxes.
    #                       Per the original author's note, mosaic is applied for the first 90% of the
    #                       epochs only (that logic lives in YoloDataset, not in this script).
    #   Cosine_scheduler    Cosine-annealing learning rate, True or False.
    #------------------------------------------------------------------------------------------------------------#
    mosaic = False
    Cosine_scheduler = False
    #----------------------------------------------------#
    #   Training is split into a frozen phase and an
    #   unfrozen phase.
    #   Frozen-phase parameters:
    #   the backbone is frozen, so the feature extractor
    #   does not change; less GPU memory is needed and
    #   only the rest of the network is fine-tuned.
    #----------------------------------------------------#
    Init_Epoch = 0
    Freeze_Epoch = 50
    Freeze_batch_size = 8
    Freeze_lr = 1e-3
    #----------------------------------------------------#
    #   Unfrozen-phase parameters:
    #   the backbone is no longer frozen, so the feature
    #   extractor changes too; more GPU memory is needed
    #   and every parameter of the network is updated.
    #----------------------------------------------------#
    UnFreeze_Epoch = 100
    Unfreeze_batch_size = 4
    Unfreeze_lr = 1e-4
    #------------------------------------------------------#
    #   Whether to use frozen training. By default the
    #   backbone is frozen first and unfrozen afterwards.
    #------------------------------------------------------#
    Freeze_Train = True
    #------------------------------------------------------#
    #   Number of DataLoader worker processes.
    #   More workers speed up data loading but use more
    #   memory; on low-memory machines use 2 or 0.
    #------------------------------------------------------#
    num_workers = 4
    #----------------------------------------------------#
    #   Paths of the annotation list files.
    #----------------------------------------------------#
    train_annotation_path = '2007_train.txt'
    val_annotation_path = '2007_val.txt'

    #---------------------------#
    #   Read the dataset lists; one sample per line.
    #---------------------------#
    with open(train_annotation_path) as f:
        train_lines = f.readlines()
    with open(val_annotation_path) as f:
        val_lines = f.readlines()
    num_train = len(train_lines)
    num_val = len(val_lines)

    #----------------------------------------------------#
    #   Load the class names.
    #----------------------------------------------------#
    class_names, num_classes = get_classes(classes_path)

    #------------------------------------------------------#
    #   Resolve the device. Fall back to the CPU with a
    #   warning instead of crashing later in .cuda() when
    #   no CUDA device is available.
    #------------------------------------------------------#
    if Cuda and not torch.cuda.is_available():
        logger.warning('Cuda is enabled but no CUDA device is available; training on the CPU instead.')
        Cuda = False
    # Checkpoints are mapped onto the device that training actually uses.
    device = torch.device('cuda' if Cuda else 'cpu')

    #------------------------------------------------------#
    #   Build the YOLOX model.
    #   SECURITY: torch.load unpickles the file; only load
    #   checkpoints from a trusted source.
    #------------------------------------------------------#
    if pruned_train:
        # The pruned checkpoint is loaded as a whole model object.
        # NOTE(review): on PyTorch versions where torch.load defaults to
        # weights_only=True this call may need weights_only=False -- confirm
        # against the installed version.
        model = torch.load(model_path, map_location=device)
    else:
        model = YOLOX(num_classes, phi)
        weights_init(model)
        #------------------------------------------------------#
        #   See the README for the weight files.
        #------------------------------------------------------#
        logger.info('Load weights {}.'.format(model_path))
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location=device)
        # Keep only the entries that exist in this model with an identical
        # shape; unknown keys are skipped instead of raising KeyError.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    if EMA:
        # NOTE(review): ema_model is created here but is not passed to
        # fit_one_epoch below, so within this script it has no visible
        # effect on training -- confirm whether it should be wired in.
        ema_model = ModelEMA(model, 0.9998)
        ema_model.updates = Init_Epoch * num_train

    # model.train() switches to training mode and returns the same module.
    model_train = model.train()
    if Cuda:
        model_train = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        model_train = model_train.cuda()

    yolo_loss = YOLOLoss(num_classes)
    loss_history = LossHistory("logs/")

    #------------------------------------------------------#
    #   The backbone features are generic, so frozen
    #   training speeds training up and also protects the
    #   weights from being destroyed early on.
    #   Init_Epoch      starting epoch
    #   Freeze_Epoch    last epoch (exclusive) of the frozen phase
    #   UnFreeze_Epoch  total number of training epochs
    #   If you hit OOM / insufficient GPU memory, reduce the
    #   batch size.
    #------------------------------------------------------#
    # Each entry: (start_epoch, end_epoch, batch_size, lr, backbone_trainable).
    # Both phases always run; Freeze_Train only controls whether the backbone's
    # requires_grad flags are toggled.
    phases = [
        (Init_Epoch, Freeze_Epoch, Freeze_batch_size, Freeze_lr, False),
        (Freeze_Epoch, UnFreeze_Epoch, Unfreeze_batch_size, Unfreeze_lr, True),
    ]
    for start_epoch, end_epoch, batch_size, lr, backbone_trainable in phases:
        # A fresh optimizer and scheduler are created for every phase.
        optimizer = optim.Adam(model_train.parameters(), lr, weight_decay=5e-4)
        if Cosine_scheduler:
            lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
        else:
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)

        train_dataset = YoloDataset(train_lines, input_shape, num_classes, end_epoch - start_epoch,
                                    mosaic=mosaic, train=True)
        val_dataset = YoloDataset(val_lines, input_shape, num_classes, end_epoch - start_epoch,
                                  mosaic=False, train=False)
        gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers,
                         pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate)
        gen_val = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers,
                             pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate)

        # Floor division matches drop_last=True: incomplete batches are dropped.
        epoch_step = num_train // batch_size
        epoch_step_val = num_val // batch_size
        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

        #------------------------------------#
        #   Freeze the backbone in the first
        #   phase, unfreeze it in the second.
        #------------------------------------#
        if Freeze_Train:
            for param in model.backbone.parameters():
                param.requires_grad = backbone_trainable

        for epoch in range(start_epoch, end_epoch):
            fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
                          epoch_step, epoch_step_val, gen, gen_val, end_epoch, Cuda)
            lr_scheduler.step()