242 lines
9.5 KiB
Python
242 lines
9.5 KiB
Python
import datetime
|
||
import os
|
||
|
||
import torch
|
||
import matplotlib
|
||
|
||
import scipy.signal
|
||
from matplotlib import pyplot as plt
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
|
||
import shutil
|
||
import numpy as np
|
||
|
||
from PIL import Image
|
||
from tqdm import tqdm
|
||
from .utils import cvtColor, preprocess_input, resize_image
|
||
from .utils_bbox import DecodeBox
|
||
from .utils_map import get_coco_map, get_map
|
||
|
||
# Select matplotlib's non-interactive Agg backend: this module only renders
# figures to image files via plt.savefig and never displays them on screen.
matplotlib.use('Agg')
|
||
|
||
|
||
class LossHistory():
    """Record per-epoch training/validation losses and visualise them.

    Every call to :meth:`append_loss` appends the two losses to
    ``epoch_loss.txt`` / ``epoch_val_loss.txt`` in ``log_dir``, logs them as
    TensorBoard scalars and re-renders ``epoch_loss.png``.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory, a TensorBoard writer and (best effort) the model graph.

        Args:
            log_dir: directory receiving the loss text files, the loss plot and the
                TensorBoard event files. Created if it does not already exist.
            model: network whose graph is exported to TensorBoard.
            input_shape: spatial input size; the dummy input is built as
                ``(2, 3, input_shape[0], input_shape[1])``, i.e. presumably
                ``(height, width)``.
        """
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok: do not crash when logging into a directory that already exists.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            # Build the dummy input on the model's own device so graph tracing also
            # works for a model that has already been moved to a GPU.
            first_param = next(iter(model.parameters()), None)
            device = first_param.device if first_param is not None else "cpu"
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1], device=device)
            self.writer.add_graph(model, dummy_input)
        except Exception as exc:
            # Graph export is optional; report the problem but never abort training.
            print("LossHistory: skipping TensorBoard graph export (%s)" % exc)

    def append_loss(self, epoch, loss, val_loss):
        """Store the losses of ``epoch``, append them to the text logs and refresh the plot.

        Args:
            epoch: step value used for the TensorBoard scalars.
            loss: training loss for this epoch.
            val_loss: validation loss for this epoch.
        """
        # The directory may have been removed since construction; recreate it if so.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    @staticmethod
    def _smoothing_window(num_points):
        """Return the Savitzky-Golay window length for ``num_points`` samples, or None.

        A window of 5 is used for short histories and 15 once 25 points exist.
        ``None`` means there are fewer points than the window, in which case
        ``scipy.signal.savgol_filter`` (default ``mode='interp'``) would reject the input.
        """
        window = 5 if num_points < 25 else 15
        return window if num_points >= window else None

    def loss_plot(self):
        """Draw raw and smoothed loss curves and save them to ``<log_dir>/epoch_loss.png``."""
        iters = range(len(self.losses))

        fig = plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

        window = self._smoothing_window(len(self.losses))
        if window is not None:
            try:
                # Cubic (polyorder 3) Savitzky-Golay smoothing of both curves.
                plt.plot(iters, scipy.signal.savgol_filter(self.losses, window, 3), 'green', linestyle='--',
                         linewidth=2, label='smooth train loss')
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, window, 3), '#8B4513', linestyle='--',
                         linewidth=2, label='smooth val loss')
            except ValueError:
                # Smoothing is purely cosmetic; skip it rather than fail the epoch.
                pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        # Close only the figure created here instead of every open figure.
        plt.close(fig)
|
||
|
||
|
||
class EvalCallback():
    """Periodically evaluate a detector on validation images and track its mAP.

    Every ``period`` epochs (when ``eval_flag`` is true) the network predicts on
    ``val_lines``. It writes per-image detection and ground-truth text files into
    ``map_out_path`` and computes the mAP. The result is appended to
    ``<log_dir>/epoch_map.txt`` and ``<log_dir>/epoch_map.png`` is redrawn.
    The temporary ``map_out_path`` directory is deleted afterwards.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda,
                 map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True,
                 MINOVERLAP=0.5, eval_flag=True, period=1):
        """Store the evaluation settings and seed the mAP history with 0.

        Args:
            net: detection model, called as ``net(images)``; replaced by the
                ``model_eval`` passed to :meth:`on_epoch_end`.
            input_shape: network input size; ``input_shape[0]`` / ``[1]`` are passed
                to ``DecodeBox`` and to ``resize_image`` as ``(input_shape[1],
                input_shape[0])``. Presumably ``(height, width)``; confirm.
            anchors, anchors_mask: forwarded unchanged to ``DecodeBox``.
            class_names: class names indexed by integer class id.
            num_classes: number of classes, forwarded to ``DecodeBox`` and NMS.
            val_lines: annotation lines. Each is whitespace-separated: the image
                path, then boxes as comma-separated integers
                ``left,top,right,bottom,class_id``.
            log_dir: directory for ``epoch_map.txt`` / ``epoch_map.png`` (created if missing).
            cuda: move input batches to the GPU when true.
            map_out_path: temporary directory for the mAP text files; removed
                after each evaluation.
            max_boxes: maximum number of (most confident) detections kept per image.
            confidence, nms_iou: thresholds forwarded to ``non_max_suppression``.
            letterbox_image: resize with gray padding instead of plain stretching.
            MINOVERLAP: IoU threshold for the ``get_map`` fallback; also shown in the plot label.
            eval_flag: when false, no evaluation or logging happens at all.
            period: evaluate on epochs where ``epoch % period == 0``.
        """
        super(EvalCallback, self).__init__()

        self.net = net
        self.input_shape = input_shape
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path
        self.max_boxes = max_boxes
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image
        self.MINOVERLAP = MINOVERLAP
        self.eval_flag = eval_flag
        self.period = period

        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]),
                                   self.anchors_mask)

        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            # Make sure the log directory exists before the first write.
            os.makedirs(self.log_dir, exist_ok=True)
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    @staticmethod
    def _image_id(image_path):
        """Return the file name of ``image_path`` without its final extension.

        ``os.path.splitext`` keeps inner dots, so ``img.001.jpg`` -> ``img.001``.
        This prevents distinct images from colliding on the same id.
        """
        return os.path.splitext(os.path.basename(image_path))[0]

    @staticmethod
    def _top_k_indices(scores, k):
        """Return the indices of the ``k`` highest ``scores``, highest first."""
        return np.argsort(np.asarray(scores))[::-1][:k]

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run the detector on ``image`` and write its detections for mAP evaluation.

        Writes ``<map_out_path>/detection-results/<image_id>.txt`` with one line
        per kept detection: ``"<class> <score> <left> <top> <right> <bottom>"``.
        The file is always created, even when nothing is detected.

        Args:
            image_id: base name of the output text file.
            image: input image (a PIL image when called from :meth:`on_epoch_end`).
            class_names: only detections whose class name is in this collection are written.
            map_out_path: root of the temporary mAP directory.
        """
        image_shape = np.array(np.shape(image)[0:2])
        # Convert to RGB so grayscale images do not break prediction; only RGB input is supported.
        image = cvtColor(image)
        # Resize to the network input size, optionally letterboxed (gray padding, no distortion).
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # Apply preprocess_input, move channels first (HWC -> CHW) and add a batch dimension.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            # decode_box yields a sequence of tensors; concatenate them along dim 1, then run NMS.
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                                                         image_shape, self.letterbox_image, conf_thres=self.confidence,
                                                         nms_thres=self.nms_iou)

        detection_path = os.path.join(map_out_path, "detection-results", image_id + ".txt")
        # `with` guarantees the file is closed on every path, including the early return.
        with open(detection_path, "w", encoding='utf-8') as f:
            if results[0] is None:
                return

            detections = results[0]
            top_label = np.array(detections[:, 6], dtype='int32')
            top_conf = detections[:, 4] * detections[:, 5]
            top_boxes = detections[:, :4]

            # Rank by confidence (not by class id) so the kept boxes are the most confident ones.
            keep = self._top_k_indices(top_conf, self.max_boxes)

            for label, box, conf in zip(top_label[keep], top_boxes[keep], top_conf[keep]):
                predicted_class = self.class_names[int(label)]
                if predicted_class not in class_names:
                    continue
                top, left, bottom, right = box
                score = str(conf)
                f.write("%s %s %s %s %s %s\n" % (
                    predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

    def _prepare_map_dirs(self):
        """Create ``map_out_path`` with its ``ground-truth`` and ``detection-results`` subfolders."""
        for sub_dir in ("ground-truth", "detection-results"):
            os.makedirs(os.path.join(self.map_out_path, sub_dir), exist_ok=True)

    def _process_annotation(self, annotation_line):
        """Write the detection and ground-truth text files for one annotation line."""
        line = annotation_line.split()
        image_id = self._image_id(line[0])
        gt_boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

        # Close the image file handle once prediction is done.
        with Image.open(line[0]) as image:
            self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

        gt_path = os.path.join(self.map_out_path, "ground-truth", image_id + ".txt")
        with open(gt_path, "w") as new_f:
            for left, top, right, bottom, obj in gt_boxes:
                new_f.write("%s %s %s %s %s\n" % (self.class_names[obj], left, top, right, bottom))

    def _compute_map(self):
        """Return the mAP of the files in ``map_out_path``.

        Uses element 1 of ``get_coco_map``'s result. If that raises (for
        whatever reason), falls back to ``get_map`` at ``MINOVERLAP``.
        """
        try:
            return get_coco_map(class_names=self.class_names, path=self.map_out_path)[1]
        except Exception:
            return get_map(self.MINOVERLAP, False, path=self.map_out_path)

    def _record_map(self, epoch, temp_map):
        """Append ``temp_map`` to the history, to ``epoch_map.txt`` and redraw ``epoch_map.png``."""
        self.maps.append(temp_map)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
            f.write(str(temp_map))
            f.write("\n")

        fig = plt.figure()
        plt.plot(self.epoches, self.maps, 'red', linewidth=2, label='train map')
        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Map %s' % str(self.MINOVERLAP))
        plt.title('A Map Curve')
        plt.legend(loc="upper right")
        plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
        # Close only the figure created here instead of every open figure.
        plt.close(fig)

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate ``model_eval`` when ``epoch`` is scheduled and record the resulting mAP.

        Args:
            epoch: current epoch number; evaluation runs when ``epoch % period == 0``.
            model_eval: model used for prediction (stored as ``self.net``).
        """
        if epoch % self.period != 0 or not self.eval_flag:
            return

        self.net = model_eval
        self._prepare_map_dirs()
        try:
            print("Get map.")
            for annotation_line in tqdm(self.val_lines):
                self._process_annotation(annotation_line)

            print("Calculate Map.")
            temp_map = self._compute_map()
            self._record_map(epoch, temp_map)
            print("Get map done.")
        finally:
            # Always remove the temporary files so a failed run leaves no stale results behind.
            shutil.rmtree(self.map_out_path, ignore_errors=True)