fujie_code/utils/utils_fit.py

152 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import torch
from tqdm import tqdm
from utils.utils import get_lr
def _move_batch(images, targets, cuda, local_rank):
    """Return the batch moved to GPU ``local_rank`` when ``cuda`` is true, otherwise unchanged.

    ``targets`` is a Python list of tensors; each element is moved individually and the
    result is still a Python list.
    """
    if cuda:
        with torch.no_grad():
            images = images.cuda(local_rank)
            targets = [ann.cuda(local_rank) for ann in targets]
    return images, targets


def _compute_loss(model_train, yolo_loss, images, targets):
    """Run one forward pass and return the sum of ``yolo_loss`` over every model output.

    ``model_train(images)`` is expected to return a sequence of per-level outputs (the
    original author notes three feature maps of different resolutions); each level's
    loss is computed separately as ``yolo_loss(level_index, level_output, targets)``.
    """
    outputs = model_train(images)
    return sum(yolo_loss(level, output, targets) for level, output in enumerate(outputs))


def _train_step(model_train, yolo_loss, optimizer, images, targets, fp16, scaler):
    """Perform one optimisation step on a batch and return the batch loss as a float."""
    optimizer.zero_grad()
    if fp16:
        # Mixed-precision path: forward/loss under autocast, backward/step through the
        # scaler (scale/step/update, i.e. the GradScaler API).
        from torch.cuda.amp import autocast
        with autocast():
            loss_value = _compute_loss(model_train, yolo_loss, images, targets)
        scaler.scale(loss_value).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss_value = _compute_loss(model_train, yolo_loss, images, targets)
        loss_value.backward()
        optimizer.step()
    return loss_value.item()


def _finish_epoch(model_train, model, loss_history, eval_callback, epoch, Epoch,
                  train_loss, val_loss, save_period, save_dir):
    """Record the epoch's mean losses, run the evaluation callback and write checkpoints."""
    loss_history.append_loss(epoch + 1, train_loss, val_loss)
    eval_callback.on_epoch_end(epoch + 1, model_train)
    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('Total Loss: %.3f || Val Loss: %.3f ' % (train_loss, val_loss))

    # Periodic checkpoint, plus always one after the final epoch.
    if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
        torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (
            epoch + 1, train_loss, val_loss)))

    # append_loss has already added this epoch's value, so comparing against min() of the
    # history is "no worse than any epoch so far".
    if len(loss_history.val_loss) <= 1 or val_loss <= min(loss_history.val_loss):
        print('Save best model to best_epoch_weights.pth')
        torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))

    torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))


def fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step,
                  epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
    """Train for one epoch, run validation, and (on rank 0 only) log and save checkpoints.

    Args:
        model_train: Module used for forward passes; put in ``train()`` mode for the
            training loop and ``eval()`` mode for validation.
        model: Module whose ``state_dict()`` is written to disk. Presumably the unwrapped
            counterpart of ``model_train`` (e.g. without a DataParallel/DDP wrapper) —
            confirm against the caller.
        yolo_loss: Callable ``yolo_loss(level_index, level_output, targets)`` returning a
            scalar loss tensor for one output level.
        loss_history: Object providing ``append_loss(epoch, train_loss, val_loss)`` and a
            ``val_loss`` sequence. Only touched when ``local_rank == 0``.
        eval_callback: Object providing ``on_epoch_end(epoch, model)``. Only touched when
            ``local_rank == 0``.
        optimizer: Optimizer stepped once per training batch.
        epoch: Zero-based index of the current epoch.
        epoch_step: Maximum number of batches drawn from ``gen``.
        epoch_step_val: Maximum number of batches drawn from ``gen_val``.
        gen, gen_val: Iterables of batches where ``batch[0]`` is the image tensor and
            ``batch[1]`` is a list of target tensors (the original author notes the
            targets are normalised — confirm against the dataset).
        Epoch: Total number of epochs (progress label and final-epoch save).
        cuda: Move each batch to GPU ``local_rank`` when true.
        fp16: Use mixed precision (``autocast`` + ``scaler``) for training steps.
        scaler: Presumably a ``torch.cuda.amp.GradScaler``; only used when ``fp16`` is true.
        save_period: Write an ``epXXX-...pth`` checkpoint every ``save_period`` epochs.
        save_dir: Directory that receives the checkpoint files.
        local_rank: Process rank. Progress bars, logging, callbacks and checkpointing run on
            rank 0 only, so other ranks may pass ``None`` for ``loss_history`` and
            ``eval_callback`` and never race on the checkpoint files.

    Mean losses are computed over the batches actually processed (not the nominal
    ``epoch_step``), so a generator that runs short does not bias the average and an
    empty loader does not raise ``ZeroDivisionError``.
    """
    loss = 0
    val_loss = 0
    num_train = 0
    num_val = 0

    # ---------------- training ----------------
    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        # Cap the epoch at exactly epoch_step batches even if the loader could yield more.
        if iteration >= epoch_step:
            break
        images, targets = _move_batch(batch[0], batch[1], cuda, local_rank)
        loss += _train_step(model_train, yolo_loss, optimizer, images, targets, fp16, scaler)
        num_train += 1
        if local_rank == 0:
            pbar.set_postfix(**{'loss': loss / num_train,
                                'lr': get_lr(optimizer)})
            pbar.update(1)

    # ---------------- validation ----------------
    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', mininterval=0.3)
    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        images, targets = _move_batch(batch[0], batch[1], cuda, local_rank)
        # No parameter update happens here, so skip building autograd graphs.
        with torch.no_grad():
            loss_value = _compute_loss(model_train, yolo_loss, images, targets)
        val_loss += loss_value.item()
        num_val += 1
        if local_rank == 0:
            pbar.set_postfix(**{'val_loss': val_loss / num_val})
            pbar.update(1)

    # ---------------- logging & checkpoints (rank 0 only) ----------------
    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        _finish_epoch(model_train, model, loss_history, eval_callback, epoch, Epoch,
                      loss / max(num_train, 1), val_loss / max(num_val, 1),
                      save_period, save_dir)