fujie_code/nets/yolo_training.py

489 lines
30 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import math
from functools import partial
import numpy as np
import torch
import torch.nn as nn
class YOLOLoss(nn.Module):
    """YOLOv3-style detection loss, evaluated for one output feature layer per call.

    ``forward(l, input, targets)`` decodes the raw head output of feature
    layer ``l``, builds the matching ground-truth tensor, and returns the
    weighted sum of a localisation term (GIoU by default), a classification
    term and an objectness term.
    """

    def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask=None):
        """Store the loss configuration.

        Args:
            anchors: sequence of ``(width, height)`` anchor sizes in input-image pixels.
            num_classes: number of object classes.
            input_shape: ``(height, width)`` of the network input.
            cuda: flag kept for backward compatibility (see the note below).
            anchors_mask: for each feature layer, the indices into ``anchors``
                used by that layer. Defaults to ``[[6, 7, 8], [3, 4, 5], [0, 1, 2]]``.
        """
        super(YOLOLoss, self).__init__()
        # A fresh list is built per instance so that instances never share
        # (and cannot accidentally mutate) one default object.
        if anchors_mask is None:
            anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        # -----------------------------------------------------------#
        #   Anchors of the 13x13 layer: [116,90],[156,198],[373,326]
        #   Anchors of the 26x26 layer: [30,61],[62,45],[59,119]
        #   Anchors of the 52x52 layer: [10,13],[16,30],[33,23]
        # -----------------------------------------------------------#
        self.anchors = anchors
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        self.anchors_mask = anchors_mask

        self.giou = True
        # Per-layer weight of the objectness term, indexed by ``l``.
        self.balance = [0.4, 1.0, 4]
        self.box_ratio = 0.05
        self.obj_ratio = 5 * (input_shape[0] * input_shape[1]) / (416 ** 2)
        self.cls_ratio = 1 * (num_classes / 80)

        self.ignore_threshold = 0.5
        # NOTE: the attribute keeps its historical name; on this instance it
        # shadows ``nn.Module.cuda()``. ``forward`` no longer depends on it.
        self.cuda = cuda

    def clip_by_tensor(self, t, t_min, t_max):
        """Clip ``t`` element-wise into ``[t_min, t_max]`` and return it as float."""
        t = t.float()
        # Each element is either kept or replaced by the violated bound.
        result = (t >= t_min).float() * t + (t < t_min).float() * t_min
        result = (result <= t_max).float() * result + (result > t_max).float() * t_max
        return result

    def MSELoss(self, pred, target):
        """Return the element-wise squared error (no reduction)."""
        return torch.pow(pred - target, 2)

    def BCELoss(self, pred, target):
        """Return the element-wise binary cross-entropy (no reduction)."""
        epsilon = 1e-7
        # Keep predictions inside (0, 1) so that neither log() sees zero.
        pred = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
        output = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
        return output

    def box_giou(self, b1, b2):
        """Compute the GIoU between two box tensors in ``xywh`` (centre) format.

        Args:
            b1: tensor of shape ``(..., 4)``, e.g. ``(batch, anchor_num, feat_h, feat_w, 4)``.
            b2: tensor broadcastable against ``b1`` with the same last dimension.

        Returns:
            Tensor with the leading shape of the inputs (the last axis is reduced).
        """
        # ----------------------------------------------------#
        #   Corners of the first set of boxes
        # ----------------------------------------------------#
        b1_xy = b1[..., :2]
        b1_wh = b1[..., 2:4]
        b1_wh_half = b1_wh / 2.
        b1_mins = b1_xy - b1_wh_half
        b1_maxes = b1_xy + b1_wh_half
        # ----------------------------------------------------#
        #   Corners of the second set of boxes
        # ----------------------------------------------------#
        b2_xy = b2[..., :2]
        b2_wh = b2[..., 2:4]
        b2_wh_half = b2_wh / 2.
        b2_mins = b2_xy - b2_wh_half
        b2_maxes = b2_xy + b2_wh_half
        # ----------------------------------------------------#
        #   Plain IoU
        # ----------------------------------------------------#
        intersect_mins = torch.max(b1_mins, b2_mins)
        intersect_maxes = torch.min(b1_maxes, b2_maxes)
        intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
        b1_area = b1_wh[..., 0] * b1_wh[..., 1]
        b2_area = b2_wh[..., 0] * b2_wh[..., 1]
        union_area = b1_area + b2_area - intersect_area
        iou = intersect_area / union_area
        # ----------------------------------------------------#
        #   Smallest box enclosing both boxes
        # ----------------------------------------------------#
        enclose_mins = torch.min(b1_mins, b2_mins)
        enclose_maxes = torch.max(b1_maxes, b2_maxes)
        enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
        # ----------------------------------------------------#
        #   GIoU = IoU - (enclosing area not covered by the union) / enclosing area
        # ----------------------------------------------------#
        enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
        giou = iou - (enclose_area - union_area) / enclose_area
        return giou

    def forward(self, l, input, targets=None):
        """Compute the loss of feature layer ``l``.

        Args:
            l: index of the feature layer; selects ``anchors_mask[l]`` and ``balance[l]``.
            input: raw head output of shape ``(bs, A * (5 + num_classes), in_h, in_w)``
                where ``A = len(anchors_mask[l])``.
            targets: per-image sequence of ``(num_boxes, 5)`` tensors whose columns
                are read as ``x, y, w, h, class``; the box columns are scaled by the
                feature-map size, so they are presumably normalised to 0~1.

        Returns:
            Scalar loss tensor.
        """
        bs = input.size(0)
        in_h = input.size(2)
        in_w = input.size(3)
        # -----------------------------------------------------------------------#
        #   Stride: how many input pixels one feature-map cell covers
        #   (e.g. 32, 16, 8 for 13x13, 26x26, 52x52 on a 416x416 input).
        # -----------------------------------------------------------------------#
        stride_h = self.input_shape[0] / in_h
        stride_w = self.input_shape[1] / in_w
        # Anchors expressed in feature-map cells instead of pixels.
        scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
        # -----------------------------------------------#
        #   bs, A*(5+num_classes), h, w  =>  bs, A, h, w, 5+num_classes
        # -----------------------------------------------#
        prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(
            0, 1, 3, 4, 2).contiguous()
        # Centre offsets inside the cell.
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        # Log-space width/height adjustments.
        w = prediction[..., 2]
        h = prediction[..., 3]
        # Objectness and per-class confidences.
        conf = torch.sigmoid(prediction[..., 4])
        pred_cls = torch.sigmoid(prediction[..., 5:])
        # -----------------------------------------------#
        #   y_true: ground truth laid out like the prediction.
        #   noobj_mask: 1 everywhere except cells that own a ground-truth box.
        #   box_loss_scale: normalised ground-truth area per positive cell.
        # -----------------------------------------------#
        y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scaled_anchors, in_h, in_w)
        # ---------------------------------------------------------------#
        #   Decode the predictions and drop from the negatives every
        #   prediction that already overlaps a ground-truth box well.
        # ----------------------------------------------------------------#
        noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)

        # Always align dtype/device with the prediction; this is a no-op when
        # they already match, so it does not need to depend on the cuda flag.
        y_true = y_true.type_as(x)
        noobj_mask = noobj_mask.type_as(x)
        box_loss_scale = box_loss_scale.type_as(x)
        # --------------------------------------------------------------------------#
        #   box_loss_scale holds w*h normalised by the feature-map size, so
        #   2 - scale gives small boxes a larger weight than large ones.
        # --------------------------------------------------------------------------#
        box_loss_scale = 2 - box_loss_scale

        loss = 0
        obj_mask = y_true[..., 4] == 1
        n = torch.sum(obj_mask)
        if n != 0:
            if self.giou:
                # Localisation term based on GIoU between decoded and true boxes.
                giou = self.box_giou(pred_boxes, y_true[..., :4]).type_as(x)
                loss_loc = torch.mean((1 - giou)[obj_mask])
            else:
                # Centre offsets use BCE, width/height use squared error.
                loss_x = torch.mean(self.BCELoss(x[obj_mask], y_true[..., 0][obj_mask]) * box_loss_scale[obj_mask])
                loss_y = torch.mean(self.BCELoss(y[obj_mask], y_true[..., 1][obj_mask]) * box_loss_scale[obj_mask])
                loss_w = torch.mean(self.MSELoss(w[obj_mask], y_true[..., 2][obj_mask]) * box_loss_scale[obj_mask])
                loss_h = torch.mean(self.MSELoss(h[obj_mask], y_true[..., 3][obj_mask]) * box_loss_scale[obj_mask])
                loss_loc = (loss_x + loss_y + loss_h + loss_w) * 0.1

            # Classification term over the positive cells only.
            loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
            loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio

        # Objectness term over positives and non-ignored negatives.
        loss_conf = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])
        loss += loss_conf * self.balance[l] * self.obj_ratio
        return loss

    @staticmethod
    def _xywh_to_corners(boxes):
        """Convert ``(N, 4)`` centre-format boxes to ``x1, y1, x2, y2``."""
        half_wh = boxes[:, 2:4] / 2
        return torch.cat((boxes[:, :2] - half_wh, boxes[:, :2] + half_wh), dim=1)

    def calculate_iou(self, _box_a, _box_b):
        """Return the pairwise IoU matrix between two sets of ``xywh`` boxes.

        Args:
            _box_a: tensor of shape ``(A, 4)``.
            _box_b: tensor of shape ``(B, 4)``.

        Returns:
            Tensor of shape ``(A, B)``.
        """
        box_a = self._xywh_to_corners(_box_a)
        box_b = self._xywh_to_corners(_box_b)
        # Broadcast every box of A against every box of B.
        max_xy = torch.min(box_a[:, None, 2:], box_b[None, :, 2:])
        min_xy = torch.max(box_a[:, None, :2], box_b[None, :, :2])
        # Disjoint boxes give a negative extent, which is clamped to zero area.
        inter = torch.clamp((max_xy - min_xy), min=0)
        inter = inter[:, :, 0] * inter[:, :, 1]

        area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1)
        area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0)
        union = area_a + area_b - inter
        return inter / union

    def get_target(self, l, targets, anchors, in_h, in_w):
        """Build the ground-truth tensors for feature layer ``l``.

        Args:
            l: feature-layer index.
            targets: per-image ``(num_boxes, 5)`` tensors (``x, y, w, h, class``).
            anchors: all anchors as ``(w, h)`` pairs in feature-map cells.
            in_h: feature-map height.
            in_w: feature-map width.

        Returns:
            ``(y_true, noobj_mask, box_loss_scale)`` as CPU tensors of shapes
            ``(bs, A, in_h, in_w, 5 + num_classes)``, ``(bs, A, in_h, in_w)`` and
            ``(bs, A, in_h, in_w)``.
        """
        bs = len(targets)
        layer_mask = self.anchors_mask[l]
        num_layer_anchors = len(layer_mask)
        # 1 marks a cell/anchor without an assigned object.
        noobj_mask = torch.ones(bs, num_layer_anchors, in_h, in_w, requires_grad=False)
        # Normalised ground-truth area, used to emphasise small objects.
        box_loss_scale = torch.zeros(bs, num_layer_anchors, in_h, in_w, requires_grad=False)
        y_true = torch.zeros(bs, num_layer_anchors, in_h, in_w, self.bbox_attrs, requires_grad=False)

        # Anchors as (0, 0, w, h) boxes; identical for every image, so built once.
        anchor_wh = torch.tensor(anchors, dtype=torch.float32)
        anchor_shapes = torch.cat((torch.zeros((anchor_wh.size(0), 2)), anchor_wh), 1)

        for b in range(bs):
            if len(targets[b]) == 0:
                continue
            # Scale the targets from normalised units to feature-map cells.
            batch_target = torch.zeros_like(targets[b])
            batch_target[:, [0, 2]] = targets[b][:, [0, 2]] * in_w
            batch_target[:, [1, 3]] = targets[b][:, [1, 3]] * in_h
            batch_target[:, 4] = targets[b][:, 4]
            # The target tensors built here live on the CPU.
            batch_target = batch_target.cpu()
            # Ground truth as (0, 0, w, h) so only the shapes are compared.
            gt_box = torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4].float()), 1)
            # Index of the best-matching anchor (among all anchors) for every box.
            best_ns = torch.argmax(self.calculate_iou(gt_box, anchor_shapes), dim=-1)

            for t, best_n in enumerate(best_ns.tolist()):
                # Boxes whose best anchor belongs to another layer are handled there.
                if best_n not in layer_mask:
                    continue
                k = layer_mask.index(best_n)
                # Cell that owns the box centre. Clamped so that a centre lying
                # exactly on the right/bottom border stays inside the grid.
                i = int(torch.floor(batch_target[t, 0]).clamp(0, in_w - 1))
                j = int(torch.floor(batch_target[t, 1]).clamp(0, in_h - 1))
                c = int(batch_target[t, 4])

                noobj_mask[b, k, j, i] = 0
                if not self.giou:
                    # Regression targets relative to the cell and the anchor.
                    y_true[b, k, j, i, 0] = batch_target[t, 0] - i
                    y_true[b, k, j, i, 1] = batch_target[t, 1] - j
                    y_true[b, k, j, i, 2] = math.log(batch_target[t, 2] / anchors[best_n][0])
                    y_true[b, k, j, i, 3] = math.log(batch_target[t, 3] / anchors[best_n][1])
                else:
                    # GIoU compares boxes directly, so the box is stored in cell units.
                    y_true[b, k, j, i, 0] = batch_target[t, 0]
                    y_true[b, k, j, i, 1] = batch_target[t, 1]
                    y_true[b, k, j, i, 2] = batch_target[t, 2]
                    y_true[b, k, j, i, 3] = batch_target[t, 3]
                y_true[b, k, j, i, 4] = 1
                y_true[b, k, j, i, c + 5] = 1
                # Box area normalised by the feature-map size.
                box_loss_scale[b, k, j, i] = batch_target[t, 2] * batch_target[t, 3] / in_w / in_h
        return y_true, noobj_mask, box_loss_scale

    def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
        """Decode the predicted boxes and clear well-overlapping negatives.

        Args:
            l: feature-layer index.
            x, y: sigmoid-activated centre offsets, shape ``(bs, A, in_h, in_w)``.
            h, w: raw height/width adjustments, same shape.
            targets: per-image ``(num_boxes, 5)`` tensors.
            scaled_anchors: all anchors as ``(w, h)`` pairs in feature-map cells.
            in_h: feature-map height.
            in_w: feature-map width.
            noobj_mask: mask from ``get_target``; modified in place.

        Returns:
            ``(noobj_mask, pred_boxes)`` where ``pred_boxes`` has shape
            ``(bs, A, in_h, in_w, 4)`` in ``xywh`` cell units.
        """
        bs = len(targets)
        # Column / row index of every cell, broadcast over batch and anchors.
        grid_x = torch.arange(in_w, dtype=torch.float32).view(1, 1, 1, in_w).type_as(x)
        grid_y = torch.arange(in_h, dtype=torch.float32).view(1, 1, in_h, 1).type_as(x)

        # Width/height of the anchors used by this layer, broadcast over batch and cells.
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        layer_anchors = torch.tensor(scaled_anchors_l, dtype=torch.float32).type_as(x)
        anchor_w = layer_anchors[:, 0].view(1, -1, 1, 1)
        anchor_h = layer_anchors[:, 1].view(1, -1, 1, 1)
        # -------------------------------------------------------#
        #   Decoded centre and size of every predicted box
        # -------------------------------------------------------#
        pred_boxes_x = torch.unsqueeze(x + grid_x, -1)
        pred_boxes_y = torch.unsqueeze(y + grid_y, -1)
        pred_boxes_w = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
        pred_boxes_h = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
        pred_boxes = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim=-1)

        for b in range(bs):
            # Images without ground truth keep their mask untouched.
            if len(targets[b]) == 0:
                continue
            pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
            # Ground truth in feature-map cells.
            batch_target = torch.zeros_like(targets[b])
            batch_target[:, [0, 2]] = targets[b][:, [0, 2]] * in_w
            batch_target[:, [1, 3]] = targets[b][:, [1, 3]] * in_h
            batch_target = batch_target[:, :4].type_as(x)
            # IoU of every ground-truth box with every predicted box.
            anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
            # Best overlap of each predicted box with any ground-truth box.
            anch_ious_max, _ = torch.max(anch_ious, dim=0)
            anch_ious_max = anch_ious_max.view(pred_boxes[b].size()[:3])
            # Predictions above the threshold are not counted as negatives.
            noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
        return noobj_mask, pred_boxes
def weights_init(net, init_type='normal', init_gain=0.02):
    """Initialise the Conv and BatchNorm2d parameters of ``net`` in place.

    Args:
        net: module whose sub-modules are visited through ``net.apply``.
        init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'``, ``'orthogonal'``;
            selects the initialiser used for modules whose class name contains ``Conv``.
        init_gain: standard deviation (``'normal'``) or gain (``'xavier'``,
            ``'orthogonal'``) handed to the initialiser; unused for ``'kaiming'``.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and a Conv module is visited.
    """

    def init_func(m):
        classname = m.__class__.__name__
        # getattr() also covers modules whose parameter exists but is None
        # (e.g. BatchNorm2d(affine=False)), which would otherwise crash on `.data`.
        weight = getattr(m, 'weight', None)
        if weight is not None and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            # Scale close to 1 and zero shift, when the layer has affine parameters.
            if weight is not None:
                torch.nn.init.normal_(weight.data, 1.0, 0.02)
            bias = getattr(m, 'bias', None)
            if bias is not None:
                torch.nn.init.constant_(bias.data, 0.0)

    print('initialize network with %s type' % init_type)
    net.apply(init_func)
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio=0.05, warmup_lr_ratio=0.1,
                     no_aug_iter_ratio=0.05, step_num=10):
    """Build a function mapping an iteration index to a learning rate.

    Args:
        lr_decay_type: ``"cos"`` selects warm-up followed by cosine decay; any other
            value selects step decay.
        lr: peak (initial) learning rate.
        min_lr: final learning rate.
        total_iters: total number of iterations the schedule spans.
        warmup_iters_ratio: fraction of ``total_iters`` used for warm-up (capped at 3 iterations).
        warmup_lr_ratio: warm-up starts from ``warmup_lr_ratio * lr`` (at least 1e-6).
        no_aug_iter_ratio: fraction of ``total_iters`` held at ``min_lr`` at the end
            (capped at 15 iterations).
        step_num: number of steps for the step schedule.

    Returns:
        A callable ``f(iters) -> lr``.
    """

    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # Quadratic ramp from warmup_lr_start up to lr.
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            lr = min_lr
        else:
            # Half cosine from lr down to min_lr between the two phases above.
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(
                    math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        # The rate is chosen so that the last of step_num steps lands on min_lr.
        # With a single step there is nothing to decay towards, and the
        # exponent 1 / (step_num - 1) would divide by zero.
        if step_num > 1:
            decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        else:
            decay_rate = 1.0
        step_size = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)
    return func
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Write the learning rate scheduled for ``epoch`` into every parameter group.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts.
        lr_scheduler_func: callable mapping ``epoch`` to a learning rate.
        epoch: value handed to ``lr_scheduler_func``.
    """
    scheduled_lr = lr_scheduler_func(epoch)
    # Every group receives the same value; other group options are left alone.
    for group in optimizer.param_groups:
        group.update(lr=scheduled_lr)