80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
|
import numpy as np
|
|||
|
from PIL import Image
|
|||
|
|
|||
|
|
|||
|
# ---------------------------------------------------------#
|
|||
|
# 将图像转换成RGB图像,防止灰度图在预测时报错。
|
|||
|
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
|||
|
# ---------------------------------------------------------#
|
|||
|
def cvtColor(image):
    """Return *image* as an RGB image so prediction never sees grayscale/RGBA/etc.

    The pipeline only supports RGB input; every other image type is
    converted to RGB.

    * Objects with a ``mode`` attribute (PIL images) are judged by their
      mode: ``'RGB'`` is returned unchanged, any other mode is passed
      through ``image.convert('RGB')``.
    * Other objects (e.g. numpy arrays) fall back to a shape check: a
      3-D array whose last axis has length 3 is returned unchanged;
      anything else is passed to ``image.convert('RGB')`` as before.

    Args:
        image: A PIL-style image, or an array-like already in (H, W, 3) form.

    Returns:
        The original object if it is already RGB, otherwise the converted image.
    """
    mode = getattr(image, 'mode', None)
    if mode is not None:
        # Deciding by mode catches 3-channel non-RGB modes such as 'YCbCr',
        # 'LAB' and 'HSV', which a pure shape test lets through. It also
        # avoids copying the whole image into an array just to read its shape.
        return image if mode == 'RGB' else image.convert('RGB')
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
|
|||
|
|
|||
|
# ---------------------------------------------------#
|
|||
|
|
|||
|
|
|||
|
# 对输入图像进行resize
|
|||
|
# ---------------------------------------------------#
|
|||
|
def resize_image(image, size, letterbox_image):
    """Resize *image* to *size* = (width, height).

    Args:
        image: A PIL image (anything exposing ``.size`` and ``.resize``).
        size: Target (width, height).
        letterbox_image: If true, scale the image to fit inside *size* while
            keeping its aspect ratio, then center it on a gray (128, 128, 128)
            RGB canvas of exactly *size*. If false, stretch it directly to
            *size* with no gray padding.

    Returns:
        The resized image. Both paths resample with bicubic interpolation.
    """
    src_w, src_h = image.size
    target_w, target_h = size

    if not letterbox_image:
        # Direct stretch: the aspect ratio is not preserved and no padding bars are added.
        return image.resize((target_w, target_h), Image.BICUBIC)

    # Largest uniform scale that still fits inside the target box.
    ratio = min(target_w / src_w, target_h / src_h)
    scaled_w = int(src_w * ratio)
    scaled_h = int(src_h * ratio)
    scaled = image.resize((scaled_w, scaled_h), Image.BICUBIC)

    canvas = Image.new('RGB', size, (128, 128, 128))
    # Center the scaled image; any odd leftover pixel goes to the right/bottom edge.
    left = (target_w - scaled_w) // 2
    top = (target_h - scaled_h) // 2
    canvas.paste(scaled, (left, top))
    return canvas
|
|||
|
|
|||
|
|
|||
|
# ---------------------------------------------------#
|
|||
|
# 获得类
|
|||
|
# ---------------------------------------------------#
|
|||
|
def get_classes(classes_path):
    """Load class names from a text file, one name per line.

    Surrounding whitespace is stripped from each line, and blank or
    whitespace-only lines are skipped. A trailing empty line therefore no
    longer produces an empty class name or inflates the class count.

    Args:
        classes_path: Path to the UTF-8 class-name file. An optional leading
            byte-order mark (as written by some Windows editors) is ignored.

    Returns:
        A tuple ``(class_names, num_classes)``.
    """
    # 'utf-8-sig' decodes plain UTF-8 identically, but also drops a leading
    # BOM that would otherwise stick to the first class name.
    with open(classes_path, encoding='utf-8-sig') as f:
        class_names = [line.strip() for line in f if line.strip()]
    return class_names, len(class_names)
|
|||
|
|
|||
|
|
|||
|
# ---------------------------------------------------#
|
|||
|
# 获得先验框
|
|||
|
# ---------------------------------------------------#
|
|||
|
def get_anchors(anchors_path):
    """Load anchor sizes from the first line of a file.

    The first line holds comma-separated numbers, read pairwise as
    ``w1,h1, w2,h2, ...``. Empty tokens (e.g. from a trailing comma) are
    ignored, and a leading UTF-8 byte-order mark is tolerated. Later lines
    are not read.

    Args:
        anchors_path: Path to the UTF-8 anchors file.

    Returns:
        A tuple ``(anchors, num_anchors)``, where ``anchors`` is a float
        numpy array of shape ``(num_anchors, 2)``.

    Raises:
        ValueError: If a token is not a number, or if the line does not
            contain a non-empty, even number of values.
    """
    # 'utf-8-sig' strips a BOM that would otherwise make float() fail on the first value.
    with open(anchors_path, encoding='utf-8-sig') as f:
        first_line = f.readline()
    values = [float(token) for token in first_line.split(',') if token.strip()]
    if not values or len(values) % 2:
        raise ValueError(
            f'expected a non-empty, even number of comma-separated values '
            f'in {anchors_path!r}, got {len(values)}'
        )
    anchors = np.array(values).reshape(-1, 2)
    return anchors, len(anchors)
|
|||
|
|
|||
|
|
|||
|
# ---------------------------------------------------#
|
|||
|
# 获得学习率
|
|||
|
# ---------------------------------------------------#
|
|||
|
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    missing = object()
    first_group = next(iter(optimizer.param_groups), missing)
    if first_group is missing:
        return None
    return first_group['lr']
|
|||
|
|
|||
|
|
|||
|
def preprocess_input(image):
    """Scale pixel values by 1/255 (0..255 becomes 0..1).

    Floating-point numpy arrays, scalars and other array-likes are divided
    in place via ``/=`` and returned, exactly as before. Integer and bool
    numpy arrays (e.g. ``np.array(pil_image)``, which is uint8) cannot hold
    the result in place. For those, a new float array is returned and the
    input is left unchanged. Previously such inputs raised a numpy casting
    error.

    Args:
        image: Pixel data, typically a numpy array.

    Returns:
        The scaled data.
    """
    if isinstance(image, np.ndarray) and image.dtype.kind in 'biu':
        # bool / signed int / unsigned int: an in-place true division cannot
        # cast its float result back into the integer buffer.
        return image / 255.0
    image /= 255.0
    return image
|
|||
|
|
|||
|
|
|||
|
def show_config(**kwargs):
    """Print the given keyword arguments as a bordered two-column table.

    Each key/value pair becomes one row. Both columns are right-aligned,
    with widths of 25 and 40 characters, and the table is framed by
    70-dash rules.
    """
    rule = '-' * 70
    row = '|%25s | %40s|'

    print('Configurations:')
    print(rule)
    print(row % ('keys', 'values'))
    print(rule)
    for name, setting in kwargs.items():
        print(row % (str(name), str(setting)))
    print(rule)
|