fujie_code/summary.py

35 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# --------------------------------------------#
# This script prints the network structure of the YOLO model.
# --------------------------------------------#
import torch
# from thop import clever_format, profile
from torchsummary import summary
from nets.yolo import YoloBody
if __name__ == "__main__":
    # Network input resolution (height, width) used to build the dummy input.
    input_shape = [416, 416]
    # Groups of anchor indices handed to YoloBody, presumably one group per
    # output head -- confirm against nets.yolo.
    anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    # Number of object classes (80 presumably matches COCO -- confirm).
    num_classes = 80

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m = YoloBody(anchors_mask, num_classes)
    # Print the module tree first, then a separator before the layer summary.
    print(m)
    print('-' * 80)
    m = m.to(device)
    # torchsummary's `device` argument defaults to "cuda"; pass the device the
    # model actually lives on ("cuda" or "cpu") so the dummy inputs it builds
    # are created on the same device as the model's weights.
    summary(m, (3, input_shape[0], input_shape[1]), device=device.type)

    # --------------------------------------------------------#
    # Optional FLOPs / parameter count. Requires `thop`: uncomment the
    # `from thop import clever_format, profile` import at the top of the
    # file together with the lines below.
    # --------------------------------------------------------#
    # dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
    # flops, params = profile(m.to(device), (dummy_input,), verbose=False)
    # --------------------------------------------------------#
    # flops * 2 because profile does not count a convolution as two operations.
    # Some papers count a convolution's multiply and add as two operations;
    # in that case multiply by 2.
    # Some papers count only the multiplications and ignore the additions;
    # in that case do not multiply by 2.
    # This code multiplies by 2, following YOLOX.
    # --------------------------------------------------------#
    # flops = flops * 2
    # flops, params = clever_format([flops, params], "%.3f")
    # print('Total GFLOPS: %s' % (flops))
    # print('Total params: %s' % (params))