Code migration
commit ff11a09b8b

@ -0,0 +1,89 @@
import cv2
import torch
from PIL import Image
from torchvision import transforms, models
import torch.nn as nn


class AgeGenderPredictor:
    def __init__(self, model_path):
        self.model = self.load_model(model_path)
        self.gender_labels = ['Female', 'Male']

    def load_model(self, model_path):
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 3)  # output: gender (2 scores) and age (1 value)
        model.load_state_dict(torch.load(model_path))
        model.eval()
        return model

    def preprocess_image(self, image):
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        input_tensor = preprocess(image)
        input_batch = input_tensor.unsqueeze(0)
        return input_batch

    def predict(self, face):
        input_batch = self.preprocess_image(face)
        if torch.cuda.is_available():
            input_batch = input_batch.to('cuda')
            self.model.to('cuda')

        with torch.no_grad():
            output = self.model(input_batch)
        gender_preds = output[:, :2]
        age_preds = output[:, -1]
        gender = gender_preds.argmax(dim=1).item()
        age = age_preds.item()
        return self.gender_labels[gender], age, self.age_group(age)

    def age_group(self, age):
        if age <= 18:
            return 'Teenager'
        elif age <= 59:
            return 'Adult'
        else:
            return 'Senior'


if __name__ == '__main__':
    # Create an AgeGenderPredictor instance
    predictor = AgeGenderPredictor('megaage_model_epoch99.pth')
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    # Open the webcam
    cap = cv2.VideoCapture(0)

    while True:
        # Read one frame
        ret, frame = cap.read()
        if not ret:
            break

        # Run face detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

        # For each detected face
        for (x, y, w, h) in faces:
            # Extract the face ROI
            face = frame[y:y + h, x:x + w]
            gender, age, age_group = predictor.predict(face)

            cv2.putText(frame, f'Gender: {gender}, Age: {int(age)}, Age Group: {age_group}', (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
            cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)

        # Show the frame
        cv2.imshow('Webcam', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Release the camera and close all windows
    cap.release()
    cv2.destroyAllWindows()
@ -0,0 +1,32 @@
# Vision-Based Age and Gender Prediction System

This project is an image-based age and gender prediction system. A ResNet50 model is trained on the MegaAge-Asian dataset; the system then detects faces in webcam video and predicts the age, gender, and age group of each detected face.

## File Structure

- `AgeGenderPredictor.py`: model loading, preprocessing, and inference logic for the age/gender predictor.
- `megaage_model_epoch99.pth`: model weights trained on the MegaAge-Asian dataset.

## Usage

1. Make sure the required Python libraries are installed: `opencv-python`, `torch`, `torchvision`, and `Pillow`.
2. Run the `AgeGenderPredictor.py` script.
3. The script opens the default webcam and starts face detection and age/gender prediction.
4. Each detected face is framed with a rectangle and annotated with the predicted gender, age, and age group.
5. Press `q` to exit.

## Model

The project uses ResNet50 as the base model, trained on the MegaAge-Asian dataset to predict the age and gender of a face image. The final model output contains three values: two gender scores (female and male, in that order, matching `gender_labels` in the code) and the estimated age.
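In code, the split looks like this; a minimal sketch mirroring the slicing in `AgeGenderPredictor.predict` (the tensor values are made up):

```python
import torch

output = torch.tensor([[0.2, 1.3, 34.7]])  # (1, 3) model output, illustrative values
gender_scores = output[:, :2]              # columns 0-1: gender scores
age = output[:, -1].item()                 # column 2: regressed age
gender = ['Female', 'Male'][gender_scores.argmax(dim=1).item()]
print(gender, age)                         # Male 34.7
```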
### The MegaAge-Asian Dataset

MegaAge-Asian is a large-scale face image dataset released by SenseTime, with 40,000 images in total. It contains Asian faces of different ages and genders, with ages ranging from 1 to 70.

## Pipeline

1. **Face detection**: OpenCV's built-in Haar cascade face detector finds faces in each video frame.
2. **Preprocessing**: each detected face is resized, cropped, and normalized to match the model's input requirements.
3. **Inference**: the preprocessed image is fed to the trained ResNet50 model, which returns gender scores and an age estimate.
4. **Post-processing**: the gender label is chosen from the gender scores, and the age value is mapped to an age group.
5. **Visualization**: the face bounding box and the predicted gender, age, and age group are drawn on the video frame (see the snippet below for programmatic use).
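For single images instead of a webcam stream, the predictor can be driven directly; a minimal sketch (the image path is illustrative, and the image should contain a cropped face):

```python
import cv2
from AgeGenderPredictor import AgeGenderPredictor

predictor = AgeGenderPredictor('megaage_model_epoch99.pth')
image = cv2.imread('face.jpg')  # BGR face crop
gender, age, age_group = predictor.predict(image)
print(f'{gender}, {int(age)} years old, {age_group}')
```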
@ -0,0 +1,149 @@
from collections import OrderedDict

import torch
import numpy as np
from matplotlib import pyplot as plt
from scipy.signal import butter, filtfilt
import pywt
from models.lstm import LSTMModel


class BPModel:
    def __init__(self, model_path, fps=30):
        self.fps = fps
        self.model = LSTMModel()
        self.load_model(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        self.warmup()

    def predict(self, frames):
        yg, g, t = self.process_frame_sequence(frames, self.fps)
        yg = yg.reshape(1, -1, 1)
        inputs = torch.tensor(yg.copy(), dtype=torch.float32)
        inputs = inputs.to(self.device)
        with torch.no_grad():
            sbp_outputs, dbp_outputs = self.model(inputs)
        sbp_outputs = sbp_outputs.cpu().detach().numpy().item()
        dbp_outputs = dbp_outputs.cpu().detach().numpy().item()
        return sbp_outputs, dbp_outputs

    def load_model(self, model_path):
        model_state_dict = torch.load(model_path)

        # Check whether the checkpoint is an OrderedDict (a state_dict)
        if not isinstance(model_state_dict, OrderedDict):
            # model_state_dict = model_state_dict.state_dict()
            # If not, the checkpoint is a pickled LSTMModel; use it directly
            self.model = model_state_dict
            return

        # Check whether the checkpoint comes from multi-GPU training
        if 'module' in model_state_dict.keys():
            self.model.load_state_dict(model_state_dict['module'])
        else:
            # Strip a possible 'module.' prefix from each parameter name
            new_state_dict = {}
            for k, v in model_state_dict.items():
                if 'module.' in k:
                    name = k[7:]
                else:
                    name = k
                new_state_dict[name] = v
            self.model.load_state_dict(new_state_dict)

    # Model warm-up
    def warmup(self):
        inputs = torch.randn(10, 250, 1)
        inputs = inputs.to(self.device)
        with torch.no_grad():
            self.model(inputs)

    def wavelet_detrend(self, signal, wavelet='sym6', level=6):
        """
        Wavelet decomposition and baseline-drift removal.

        Args:
            signal (numpy.ndarray): input signal
            wavelet (str): wavelet basis name, default 'sym6'
            level (int): decomposition level, default 6

        Returns:
            detrended_signal (numpy.ndarray): signal with baseline drift removed
        """
        # Perform the wavelet decomposition
        coeffs = pywt.wavedec(signal, wavelet, level=level)

        # The level-6 approximation coefficients hold the baseline drift
        cA6 = coeffs[0]

        # Zero the baseline component and reconstruct the signal
        coeffs[0] = np.zeros_like(cA6)
        detrended_signal = pywt.waverec(coeffs, wavelet)

        return detrended_signal

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        nyq = 0.5 * fs
        low = lowcut / nyq
        high = highcut / nyq
        b, a = butter(order, [low, high], btype='band')
        return b, a

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        y = filtfilt(b, a, data)
        return y

    def process_frame_sequence(self, frames, fps):
        """
        Process a frame sequence.

        Args:
            frames (list): all frames, each a numpy.ndarray

        Returns:
            yg (numpy.ndarray): processed green-channel signal
            green (numpy.ndarray): raw green-channel signal
            t (list): time axis in seconds, starting at 0
        """
        all_frames = frames

        green = []
        for frame in all_frames:
            r, g, b = (frame.mean(axis=0)).mean(axis=0)
            green.append(g)

        t = [i / fps for i in range(len(all_frames))]

        g_detrended = self.wavelet_detrend(green)
        lowcut = 0.6
        highcut = 8
        datag = g_detrended
        yg = self.butter_bandpass_filter(datag, lowcut, highcut, fps, order=4)

        # self.plot(green, t, 'Original Green Channel', color='green')
        # self.plot(g_detrended, t, 'Detrended Green Channel', color='red')
        # self.plot(yg, t, 'Filtered Green Channel', color='blue')

        return yg, green, t

    def plot(self, yg, t, title, color='green', figsize=(30, 10)):
        plt.figure(figsize=figsize)
        plt.plot(t, yg, label=title, color=color)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.show()
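As a quick check of the preprocessing chain above (wavelet detrend plus band-pass), here is a self-contained sketch on a synthetic 30 fps signal; the 1.2 Hz sine stands in for a ~72 bpm pulse and the slow ramp for baseline wander, so all values are illustrative only:

import numpy as np
import pywt
from scipy.signal import butter, filtfilt

fs = 30
t = np.arange(750) / fs                           # 25 s of samples at 30 fps
signal = np.sin(2 * np.pi * 1.2 * t) + 0.5 * t    # pulse plus linear drift

coeffs = pywt.wavedec(signal, 'sym6', level=6)
coeffs[0] = np.zeros_like(coeffs[0])              # zero the approximation (drift)
detrended = pywt.waverec(coeffs, 'sym6')

nyq = 0.5 * fs
b, a = butter(4, [0.6 / nyq, 8 / nyq], btype='band')
filtered = filtfilt(b, a, detrended)              # keep only the pulse band
print(filtered.shape)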
@ -0,0 +1,63 @@
# rPPG-Based Blood Pressure Estimation System

This project estimates blood pressure from remote photoplethysmography (rPPG). An LSTM network is trained on several datasets; the system then extracts the optical pulse signal from video and predicts the subject's systolic (SBP) and diastolic (DBP) blood pressure.

## Core Files

- `BPApi.py`: core logic of the BP estimator, including signal preprocessing and model inference.
- `lstm.py`: the LSTM network architecture used for BP estimation.
- `video.py`: main script for video capture, face detection, and BP estimation.
- `best_model.pth`: the best model weights, trained on several datasets.

## Usage

1. Make sure the required Python libraries are installed: `opencv-python`, `torch`, `numpy`, `scipy`, and `pywavelets`.
2. Run the `video.py` script.
3. The script opens the default webcam and starts face detection and BP estimation.
4. The detected face region is used for BP estimation, and the predictions are shown in the video window.
5. Press `q` to exit.

## Model

The project uses an LSTM network as the base model, pretrained on a large-scale PPG dataset and then fine-tuned on rPPG datasets, to predict a subject's SBP and DBP. The model outputs two values: the SBP and DBP predictions.
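The hand-off from pretraining to fine-tuning works by reloading a pickled checkpoint and copying its weights into a fresh model, as the fine-tuning scripts in this commit do (the path here is one example checkpoint name):

```python
import torch
from models.lstm import LSTMModel

model = LSTMModel()
# Checkpoints are saved with torch.save(model.module, ...), i.e. as pickled
# models, so reload one and copy its state_dict into a fresh instance.
pretrained = torch.load('weights/MIMIC_full/last.pth', map_location='cpu')
model.load_state_dict(pretrained.state_dict())
```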
### Datasets

The following three public datasets were used for training:

1. **MIMIC-III**: 9,054,000 PPG sequences with SBP/DBP labels.
2. **UKL-rPPG**: 7,851 rPPG sequences with SBP/DBP labels.
3. **iPPG-BP**: 2,120 rPPG sequences with SBP/DBP labels.

## Pipeline

1. **Video capture**:
   - Initialize an OpenCV video-capture object and read the video's frame rate.

2. **Face detection**:
   - Run the Haar cascade face detector on every frame.
   - If a face is found, record its bounding-box coordinates.

3. **Frame-sequence extraction**:
   - Keep a fixed-length (e.g. 250-frame) circular queue of the most recent face frames (the sketch after this list shows the loop in code).
   - Append each newly detected face to the queue.

4. **Signal preprocessing** (once the queue is full):
   - Extract the green-channel signal from the face frames.
   - Detrend with a wavelet transform to remove baseline drift.
   - Band-pass filter out high- and low-frequency noise, keeping the useful pulse band.

5. **Inference**:
   - Feed the preprocessed green-channel signal to the LSTM model.
   - The model outputs the SBP and DBP predictions.

6. **Visualization**:
   - Draw the face bounding box on the frame.
   - Overlay the predicted SBP and DBP values.

7. **Loop**:
   - Repeat steps 2-6 for each new frame, continuously detecting faces, estimating BP, and visualizing.

8. **Exit**:
   - When the user presses the designated key ('q'), exit, releasing the capture object and closing all windows.
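Since `video.py` itself is not included in this commit, the loop above can only be sketched; window size, variable names, and overlay format here are illustrative:

```python
from collections import deque

import cv2
from BPApi import BPModel

cap = cv2.VideoCapture(0)
fps = cap.get(cv2.CAP_PROP_FPS) or 30
model = BPModel('best_model.pth', fps=int(fps))
detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
frames = deque(maxlen=250)  # circular queue of recent face frames

while True:
    ret, frame = cap.read()
    if not ret:
        break
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    faces = detector.detectMultiScale(gray, 1.1, 5)
    for (x, y, w, h) in faces[:1]:          # track the first face only
        frames.append(frame[y:y + h, x:x + w])
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
    if len(frames) == frames.maxlen:        # queue full: run inference
        sbp, dbp = model.predict(list(frames))
        cv2.putText(frame, f'SBP: {sbp:.0f}  DBP: {dbp:.0f}', (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    cv2.imshow('rPPG BP', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
```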
@ -0,0 +1,285 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import h5py


def custom_collate_fn(batch):
    X, y_SBP, y_DBP = zip(*batch)

    X = torch.tensor(np.array(X), dtype=torch.float32)
    y_SBP = torch.tensor(y_SBP, dtype=torch.float32)
    y_DBP = torch.tensor(y_DBP, dtype=torch.float32)

    return X, y_SBP, y_DBP


class BPDataset(Dataset):
    def __init__(self, X_data, y_SBP, y_DBP):
        self.X_data = X_data
        self.y_SBP = y_SBP
        self.y_DBP = y_DBP

    def __len__(self):
        return len(self.y_SBP)

    def __getitem__(self, idx):
        # X_sample = self.X_data[idx * 250:(idx + 1) * 250]
        X_sample = self.X_data[idx]
        y_SBP_sample = self.y_SBP[idx]
        y_DBP_sample = self.y_DBP[idx]

        return X_sample, y_SBP_sample, y_DBP_sample


class BPDataLoader:
    def __init__(self, data_dir, val_split=0.2, batch_size=32, shuffle=True, data_type='npy'):
        self.data_dir = data_dir
        self.val_split = val_split
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.train_dataloader = None
        self.val_dataloader = None
        self.data_type = data_type

    def load_data(self):
        X_BP_path = os.path.join(self.data_dir, 'X_BP.npy')
        y_DBP_path = os.path.join(self.data_dir, 'Y_DBP.npy')
        y_SBP_path = os.path.join(self.data_dir, 'Y_SBP.npy')

        X_BP = np.load(X_BP_path)
        # Reshape the data to (batch_size, 250, 1)
        X_BP = X_BP.reshape(-1, 250, 1)

        y_DBP = np.load(y_DBP_path)
        y_SBP = np.load(y_SBP_path)

        return X_BP, y_DBP, y_SBP

    def load_data_UKL_h5(self):
        X_BP_path = os.path.join(self.data_dir, 'rPPG-BP-UKL_rppg_7s.h5')
        with h5py.File(X_BP_path, 'r') as f:
            rppg = f.get('rppg')
            BP = f.get('label')
            rppg = np.array(rppg)
            BP = np.array(BP)

        # Reshape the data from (875, 7851) to (7851, 875, 1)
        rppg = rppg.transpose(1, 0)
        rppg = rppg.reshape(-1, 875, 1)

        X_BP = rppg
        y_DBP = BP[1]
        y_SBP = BP[0]

        return X_BP, y_DBP, y_SBP

    def load_data_MIMIC_h5(self):
        X_BP_path = os.path.join(self.data_dir, 'MIMIC-III_ppg_dataset.h5')

        # List the files under data_dir
        files = os.listdir(self.data_dir)

        # Check whether preprocessed data already exists
        if 'X_MIMIC_BP.npy' in files and 'Y_MIMIC_DBP.npy' in files and 'Y_MIMIC_SBP.npy' in files:
            print('loading preprocessed data.....')

            X_BP = np.load(os.path.join(self.data_dir, 'X_MIMIC_BP.npy'))
            y_DBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_DBP.npy'))
            y_SBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_SBP.npy'))

            return X_BP, y_DBP, y_SBP

        with h5py.File(X_BP_path, 'r') as f:
            ppg = f.get('ppg')
            BP = f.get('label')
            ppg = np.array(ppg)
            BP = np.array(BP)

        # Find the min and max SBP values
        max_sbp = np.max(BP[:, 0])
        min_sbp = np.min(BP[:, 0])

        # Round them outward to multiples of 10
        max_sbp = 10 - max_sbp % 10 + max_sbp
        min_sbp = min_sbp - min_sbp % 10

        # Build 10-mmHg bins
        bins = np.arange(min_sbp, max_sbp, 10)

        print(bins)

        sampled_ppg_data = []
        sampled_bp_data = []

        for i in range(len(bins) - 1):
            # Select the data falling in the current bin
            bin_data_sbp_dbp = BP[(BP[:, 0] >= bins[i]) & (BP[:, 0] < bins[i + 1])]
            bin_data_ppg = ppg[(BP[:, 0] >= bins[i]) & (BP[:, 0] < bins[i + 1])]

            # If the bin is non-empty
            if len(bin_data_sbp_dbp) > 0:
                # Randomly sample 10% of the data in this bin
                num_samples = int(len(bin_data_sbp_dbp) * 0.1)
                indices = np.random.choice(len(bin_data_sbp_dbp), num_samples, replace=False)
                sampled_bin_data_sbp_dbp = bin_data_sbp_dbp[indices]
                sampled_bin_data_ppg = bin_data_ppg[indices]

                # Append the sampled data
                sampled_bp_data.append(sampled_bin_data_sbp_dbp)
                sampled_ppg_data.append(sampled_bin_data_ppg)

        # Concatenate the per-bin samples into NumPy arrays
        ppg = np.concatenate(sampled_ppg_data, axis=0)
        BP = np.concatenate(sampled_bp_data, axis=0)

        print(ppg.shape, BP.shape)

        # Reshape the data from (N, 875) to (N, 875, 1)
        ppg = ppg.reshape(-1, 875, 1)

        X_BP = ppg

        # Column 1 is DBP, column 0 is SBP
        y_DBP = BP[:, 1]
        y_SBP = BP[:, 0]

        # Save the sampled data to disk
        np.save('data/X_MIMIC_BP.npy', X_BP)
        np.save('data/Y_MIMIC_DBP.npy', y_DBP)
        np.save('data/Y_MIMIC_SBP.npy', y_SBP)

        return X_BP, y_DBP, y_SBP

    def load_data_MIMIC_h5_full(self):
        X_BP_path = os.path.join(self.data_dir, 'MIMIC-III_ppg_dataset.h5')

        # List the files under data_dir
        files = os.listdir(self.data_dir)

        # Check whether preprocessed data already exists
        if 'X_MIMIC_BP_full.npy' in files and 'Y_MIMIC_DBP_full.npy' in files and 'Y_MIMIC_SBP_full.npy' in files:
            print('loading preprocessed data.....')

            X_BP = np.load(os.path.join(self.data_dir, 'X_MIMIC_BP_full.npy'))
            y_DBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_DBP_full.npy'))
            y_SBP = np.load(os.path.join(self.data_dir, 'Y_MIMIC_SBP_full.npy'))

            return X_BP, y_DBP, y_SBP

        with h5py.File(X_BP_path, 'r') as f:
            ppg = f.get('ppg')
            BP = f.get('label')
            ppg = np.array(ppg)
            BP = np.array(BP)

        # Reshape the data from (9054000, 875) to (9054000, 875, 1)
        ppg = ppg.reshape(-1, 875, 1)

        X_BP = ppg

        # Column 1 is DBP, column 0 is SBP
        y_DBP = BP[:, 1]
        y_SBP = BP[:, 0]

        print("data shape:", X_BP.shape, y_DBP.shape, y_SBP.shape)

        print("saving data.....")

        # Save the arrays to disk
        np.save('data/X_MIMIC_BP_full.npy', X_BP)
        np.save('data/Y_MIMIC_DBP_full.npy', y_DBP)
        np.save('data/Y_MIMIC_SBP_full.npy', y_SBP)

        print("data saved.....")

        return X_BP, y_DBP, y_SBP

    def create_dataset(self, X_data, y_SBP, y_DBP):
        return BPDataset(X_data, y_SBP, y_DBP)

    def split_data(self, X_data, y_SBP, y_DBP):
        X_train, X_val, y_train_SBP, y_val_SBP, y_train_DBP, y_val_DBP = train_test_split(
            X_data, y_SBP, y_DBP, test_size=self.val_split, random_state=42
        )

        # print(X_train.shape, X_val.shape, y_train_SBP.shape, y_val_SBP.shape, y_train_DBP.shape, y_val_DBP.shape)

        train_dataset = self.create_dataset(X_train, y_train_SBP, y_train_DBP)
        val_dataset = self.create_dataset(X_val, y_val_SBP, y_val_DBP)

        return train_dataset, val_dataset

    def create_dataloaders(self):
        if self.data_type == 'UKL':
            X_data, y_DBP, y_SBP = self.load_data_UKL_h5()
        elif self.data_type == 'MIMIC':
            X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5()
        elif self.data_type == 'MIMIC_full':
            X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5_full()
        else:
            X_data, y_DBP, y_SBP = self.load_data()
        train_dataset, val_dataset = self.split_data(X_data, y_SBP, y_DBP)

        self.train_dataloader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=self.shuffle, collate_fn=custom_collate_fn
        )
        self.val_dataloader = DataLoader(
            val_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=custom_collate_fn
        )

    def get_dataloaders(self):
        if self.train_dataloader is None or self.val_dataloader is None:
            self.create_dataloaders()

        return self.train_dataloader, self.val_dataloader

    def get_distributed_dataloaders(self, world_size, rank):
        if self.data_type == 'UKL':
            X_data, y_DBP, y_SBP = self.load_data_UKL_h5()
        elif self.data_type == 'MIMIC':
            X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5()
        elif self.data_type == 'MIMIC_full':
            X_data, y_DBP, y_SBP = self.load_data_MIMIC_h5_full()
        else:
            X_data, y_DBP, y_SBP = self.load_data()
        train_dataset, val_dataset = self.split_data(X_data, y_SBP, y_DBP)

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank, shuffle=True
        )
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, num_replicas=world_size, rank=rank, shuffle=False
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            sampler=train_sampler,
            collate_fn=custom_collate_fn,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            sampler=val_sampler,
            collate_fn=custom_collate_fn,
        )

        return train_dataloader, val_dataloader, train_sampler, val_sampler


# Usage example
#
# data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=32, data_type='MIMIC')
# train_dataloader, val_dataloader = data_loader.get_dataloaders()
#
# for i, (X, y_SBP, y_DBP) in enumerate(train_dataloader):
#     print(f"Batch {i+1}: X.shape={X.shape}, y_SBP.shape={y_SBP.shape}, y_DBP.shape={y_DBP.shape}")
#     if i == 2:
#         break
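# The distributed variant also returns the samplers; a sketch assuming a
# two-process run (in the training scripts world_size/rank come from the
# DDP launch, so these values are illustrative):
#
# loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=32, data_type='UKL')
# train_dl, val_dl, train_sampler, val_sampler = loader.get_distributed_dataloaders(world_size=2, rank=0)
# train_sampler.set_epoch(0)  # reshuffle the per-rank shard each epoch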
@ -0,0 +1,195 @@
import os
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataloader import BPDataLoader
from models.lstm import LSTMModel


# TensorBoard writer
writer = SummaryWriter()

# Training hyperparameters
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 0.0005


def train(gpu, args):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '12355'

    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend="nccl",
        # init_method="env://",
        world_size=args.world_size,
        rank=rank,
    )

    # Select the GPU for this process
    torch.cuda.set_device(gpu)

    # Build the model on this GPU and wrap it in DDP
    model = LSTMModel().to(gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Loss function
    criterion = nn.MSELoss().to(gpu)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Learning-rate scheduler: cosine annealing after the warm-up phase
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)

    # Prepare the data loaders
    data_type = 'MIMIC_full'

    data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
    train_loader, val_loader, train_sampler, val_sampler = data_loader.get_distributed_dataloaders(rank=gpu, world_size=args.world_size)

    best_val_loss_sbp = float('inf')
    best_val_loss_dbp = float('inf')

    for epoch in range(max_epochs):
        if epoch < warmup_epochs:
            # Linear warm-up from 1e-6 to the base LR (5e-4)
            warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        train_sampler.set_epoch(epoch)
        train_loss = run_train(model, train_loader, optimizer, criterion, epoch, gpu)

        val_loss_sbp, val_loss_dbp = run_evaluate(model, val_loader, criterion, gpu)

        if gpu == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
            writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)

            print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

            if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
                best_val_loss_sbp = val_loss_sbp
                best_val_loss_dbp = val_loss_dbp
                torch.save(model.module, f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')

            torch.save(model.module, f'weights/{data_type}/last.pth')

        scheduler.step()

    writer.close()


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def run_train(model, dataloader, optimizer, criterion, epoch, gpu):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), disable=(gpu != 0), desc=f"GPU{gpu} Epoch {epoch+1}/{max_epochs}")
    for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
        inputs = inputs.cuda(gpu, non_blocking=True)
        sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
        dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)

        optimizer.zero_grad()

        sbp_outputs, dbp_outputs = model(inputs)

        sbp_outputs = sbp_outputs.squeeze(1)
        dbp_outputs = dbp_outputs.squeeze(1)

        loss_sbp = criterion(sbp_outputs, sbp_labels)
        loss_dbp = criterion(dbp_outputs, dbp_labels)

        loss = loss_sbp + loss_dbp
        reduced_loss = reduce_tensor(loss)

        reduced_loss.backward()
        optimizer.step()

        running_loss += reduced_loss.item()
        pbar.set_postfix(loss=running_loss / (i + 1))

    return running_loss / len(dataloader)


def run_evaluate(model, dataloader, criterion, gpu):
    model.eval()
    running_loss_sbp = 0.0
    running_loss_dbp = 0.0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.cuda(gpu, non_blocking=True)
            sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
            dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)

            sbp_outputs, dbp_outputs = model(inputs)

            sbp_outputs = sbp_outputs.squeeze(1)
            dbp_outputs = dbp_outputs.squeeze(1)

            loss_sbp = criterion(sbp_outputs, sbp_labels)
            loss_dbp = criterion(dbp_outputs, dbp_labels)

            reduced_loss_sbp = reduce_tensor(loss_sbp)
            reduced_loss_dbp = reduce_tensor(loss_dbp)

            running_loss_sbp += reduced_loss_sbp.item()
            running_loss_dbp += reduced_loss_dbp.item()

    eval_loss_sbp = running_loss_sbp / len(dataloader)
    eval_loss_dbp = running_loss_dbp / len(dataloader)

    return eval_loss_sbp, eval_loss_dbp


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--nr", type=int, default=0)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    ngpus_per_node = torch.cuda.device_count()

    if ngpus_per_node > 4:
        ngpus_per_node = 4

    args.world_size = ngpus_per_node
    args.gpus = max(ngpus_per_node, 1)
    mp.spawn(train, nprocs=args.gpus, args=(args,))


# Make sure the checkpoint directory exists
def check_path(data_type):
    if not os.path.exists('weights'):
        os.makedirs('weights')
    if not os.path.exists(f'weights/{data_type}'):
        os.makedirs(f'weights/{data_type}')


if __name__ == "__main__":
    check_path('MIMIC_full')
    main()
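For reference, the warm-up-plus-cosine schedule used above can be traced offline; a standalone sketch with a dummy parameter (not part of the training script). Note that, as in the script, the cosine scheduler also steps during the warm-up epochs:

import torch
from torch import optim

opt = optim.Adam([torch.zeros(1, requires_grad=True)], lr=0.0005)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=90, eta_min=1e-6)
for epoch in range(100):
    if epoch < 10:  # linear warm-up over the first 10 epochs
        for g in opt.param_groups:
            g['lr'] = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / 10
    print(epoch, opt.param_groups[0]['lr'])
    sched.step()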
@ -0,0 +1,200 @@
import os
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataloader import BPDataLoader
from models.lstm import LSTMModel


# TensorBoard writer
writer = SummaryWriter()

# Training hyperparameters
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 0.0005


def train(gpu, args):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '12355'

    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend="nccl",
        # init_method="env://",
        world_size=args.world_size,
        rank=rank,
    )

    # Select the GPU for this process
    torch.cuda.set_device(gpu)

    # Build the model on this GPU
    model = LSTMModel().to(gpu)

    w = torch.load(r'weights/MIMIC_full/best_90_lstm_model_sbp267.4183_dbp89.7367.pth',
                   map_location=torch.device(f'cuda:{gpu}'))

    # Load the pretrained weights
    model.load_state_dict(w.state_dict())

    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Loss function
    criterion = nn.MSELoss().to(gpu)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Learning-rate scheduler: cosine annealing after the warm-up phase
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)

    # Prepare the data loaders
    data_type = 'UKL'

    data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
    train_loader, val_loader, train_sampler, val_sampler = data_loader.get_distributed_dataloaders(rank=gpu, world_size=args.world_size)

    best_val_loss_sbp = float('inf')
    best_val_loss_dbp = float('inf')

    for epoch in range(max_epochs):
        if epoch < warmup_epochs:
            # Linear warm-up from 1e-6 to the base LR (5e-4)
            warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        train_sampler.set_epoch(epoch)
        train_loss = run_train(model, train_loader, optimizer, criterion, epoch, gpu)

        val_loss_sbp, val_loss_dbp = run_evaluate(model, val_loader, criterion, gpu)

        if gpu == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
            writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)

            print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

            if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
                best_val_loss_sbp = val_loss_sbp
                best_val_loss_dbp = val_loss_dbp
                torch.save(model.module, f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')

            torch.save(model.module, f'weights/{data_type}/last.pth')

        scheduler.step()

    writer.close()


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def run_train(model, dataloader, optimizer, criterion, epoch, gpu):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), disable=(gpu != 0), desc=f"GPU{gpu} Epoch {epoch+1}/{max_epochs}")
    for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
        inputs = inputs.cuda(gpu, non_blocking=True)
        sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
        dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)

        optimizer.zero_grad()

        sbp_outputs, dbp_outputs = model(inputs)

        sbp_outputs = sbp_outputs.squeeze(1)
        dbp_outputs = dbp_outputs.squeeze(1)

        loss_sbp = criterion(sbp_outputs, sbp_labels)
        loss_dbp = criterion(dbp_outputs, dbp_labels)

        loss = loss_sbp + loss_dbp
        reduced_loss = reduce_tensor(loss)

        reduced_loss.backward()
        optimizer.step()

        running_loss += reduced_loss.item()
        pbar.set_postfix(loss=running_loss / (i + 1))

    return running_loss / len(dataloader)


def run_evaluate(model, dataloader, criterion, gpu):
    model.eval()
    running_loss_sbp = 0.0
    running_loss_dbp = 0.0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.cuda(gpu, non_blocking=True)
            sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
            dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)

            sbp_outputs, dbp_outputs = model(inputs)

            sbp_outputs = sbp_outputs.squeeze(1)
            dbp_outputs = dbp_outputs.squeeze(1)

            loss_sbp = criterion(sbp_outputs, sbp_labels)
            loss_dbp = criterion(dbp_outputs, dbp_labels)

            reduced_loss_sbp = reduce_tensor(loss_sbp)
            reduced_loss_dbp = reduce_tensor(loss_dbp)

            running_loss_sbp += reduced_loss_sbp.item()
            running_loss_dbp += reduced_loss_dbp.item()

    eval_loss_sbp = running_loss_sbp / len(dataloader)
    eval_loss_dbp = running_loss_dbp / len(dataloader)

    return eval_loss_sbp, eval_loss_dbp


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--nr", type=int, default=0)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    ngpus_per_node = torch.cuda.device_count()

    if ngpus_per_node > 4:
        ngpus_per_node = 4

    args.world_size = ngpus_per_node
    args.gpus = max(ngpus_per_node, 1)
    mp.spawn(train, nprocs=args.gpus, args=(args,))


def check_path(data_type):
    if not os.path.exists('weights'):
        os.makedirs('weights')
    if not os.path.exists(f'weights/{data_type}'):
        os.makedirs(f'weights/{data_type}')


if __name__ == "__main__":
    check_path('UKL')
    main()
@ -0,0 +1,198 @@
import os
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataloader import BPDataLoader
from models.lstm import LSTMModel


# TensorBoard writer
writer = SummaryWriter()

# Training hyperparameters
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 0.0005


def train(gpu, args):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '12355'

    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend="nccl",
        # init_method="env://",
        world_size=args.world_size,
        rank=rank,
    )

    # Select the GPU for this process
    torch.cuda.set_device(gpu)

    # Build the model on this GPU
    model = LSTMModel().to(gpu)

    w = torch.load(r'weights/UKL/best_99_lstm_model_sbp90.9980_dbp51.0640.pth',
                   map_location=torch.device(f'cuda:{gpu}'))

    # Load the pretrained weights
    model.load_state_dict(w.state_dict())

    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Loss function
    criterion = nn.MSELoss().to(gpu)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Learning-rate scheduler: cosine annealing after the warm-up phase
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)

    # Prepare the data loaders
    data_type = 'X'

    data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
    train_loader, val_loader, train_sampler, val_sampler = data_loader.get_distributed_dataloaders(rank=gpu, world_size=args.world_size)

    best_val_loss_sbp = float('inf')
    best_val_loss_dbp = float('inf')

    for epoch in range(max_epochs):
        if epoch < warmup_epochs:
            # Linear warm-up from 1e-6 to the base LR (5e-4)
            warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        train_sampler.set_epoch(epoch)
        train_loss = run_train(model, train_loader, optimizer, criterion, epoch, gpu)

        val_loss_sbp, val_loss_dbp = run_evaluate(model, val_loader, criterion, gpu)

        if gpu == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
            writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)

            print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

            if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
                best_val_loss_sbp = val_loss_sbp
                best_val_loss_dbp = val_loss_dbp
                torch.save(model.module, f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')

            torch.save(model.module, f'weights/{data_type}/last.pth')

        scheduler.step()

    writer.close()


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def run_train(model, dataloader, optimizer, criterion, epoch, gpu):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), disable=(gpu != 0), desc=f"GPU{gpu} Epoch {epoch+1}/{max_epochs}")
    for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
        inputs = inputs.cuda(gpu, non_blocking=True)
        sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
        dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)

        optimizer.zero_grad()

        sbp_outputs, dbp_outputs = model(inputs)

        sbp_outputs = sbp_outputs.squeeze(1)
        dbp_outputs = dbp_outputs.squeeze(1)

        loss_sbp = criterion(sbp_outputs, sbp_labels)
        loss_dbp = criterion(dbp_outputs, dbp_labels)

        loss = loss_sbp + loss_dbp
        reduced_loss = reduce_tensor(loss)

        reduced_loss.backward()
        optimizer.step()

        running_loss += reduced_loss.item()
        pbar.set_postfix(loss=running_loss / (i + 1))

    return running_loss / len(dataloader)


def run_evaluate(model, dataloader, criterion, gpu):
    model.eval()
    running_loss_sbp = 0.0
    running_loss_dbp = 0.0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.cuda(gpu, non_blocking=True)
            sbp_labels = sbp_labels.cuda(gpu, non_blocking=True)
            dbp_labels = dbp_labels.cuda(gpu, non_blocking=True)

            sbp_outputs, dbp_outputs = model(inputs)

            sbp_outputs = sbp_outputs.squeeze(1)
            dbp_outputs = dbp_outputs.squeeze(1)

            loss_sbp = criterion(sbp_outputs, sbp_labels)
            loss_dbp = criterion(dbp_outputs, dbp_labels)

            reduced_loss_sbp = reduce_tensor(loss_sbp)
            reduced_loss_dbp = reduce_tensor(loss_dbp)

            running_loss_sbp += reduced_loss_sbp.item()
            running_loss_dbp += reduced_loss_dbp.item()

    eval_loss_sbp = running_loss_sbp / len(dataloader)
    eval_loss_dbp = running_loss_dbp / len(dataloader)

    return eval_loss_sbp, eval_loss_dbp


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--nr", type=int, default=0)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    ngpus_per_node = torch.cuda.device_count()

    if ngpus_per_node > 4:
        ngpus_per_node = 4

    args.world_size = ngpus_per_node
    args.gpus = max(ngpus_per_node, 1)
    mp.spawn(train, nprocs=args.gpus, args=(args,))


def check_path(data_type):
    if not os.path.exists('weights'):
        os.makedirs('weights')
    if not os.path.exists(f'weights/{data_type}'):
        os.makedirs(f'weights/{data_type}')


if __name__ == "__main__":
    check_path('X')
    main()
@ -0,0 +1,62 @@
import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=128, output_size=2):
        super(LSTMModel, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.conv1d = nn.Conv1d(input_size, 64, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.lstm1 = nn.LSTM(64, hidden_size, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_size * 2, hidden_size, bidirectional=True, batch_first=True)
        self.lstm3 = nn.LSTM(hidden_size * 2, 64, bidirectional=False, batch_first=True)
        self.fc1 = nn.Linear(64, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc_sbp = nn.Linear(128, 1)
        self.fc_dbp = nn.Linear(128, 1)

    def forward(self, x):
        # Pass the input through the Conv1d layer
        x = self.conv1d(x.permute(0, 2, 1).contiguous())
        x = self.relu(x)
        x = x.permute(0, 2, 1).contiguous()

        # Pass the features through the stacked LSTM layers
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        x, _ = self.lstm3(x)

        # Keep only the last time step
        x = x[:, -1, :]

        # Fully connected layers on the LSTM output
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))

        # Two linear heads produce the final outputs
        sbp = self.fc_sbp(x)
        dbp = self.fc_dbp(x)

        return sbp, dbp


if __name__ == "__main__":
    # Build a model instance
    model = LSTMModel()

    # Example input
    batch_size = 64
    seq_len = 1250
    input_size = 1
    input_data = torch.randn(batch_size, seq_len, input_size)

    # Run a forward pass
    sbp, dbp = model(input_data)
    print(sbp.shape, dbp.shape)  # output: torch.Size([64, 1]) torch.Size([64, 1])
@ -0,0 +1,138 @@
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataloader import BPDataLoader
from models.lstm import LSTMModel

# Build the model
model = LSTMModel()

# Training hyperparameters
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 0.0005

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Learning-rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)

# TensorBoard writer
writer = SummaryWriter()


# Training function
def train(model, dataloader, epoch, device, batch_size):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{max_epochs}")
    for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
        inputs = inputs.to(device)
        sbp_labels = sbp_labels.to(device)
        dbp_labels = dbp_labels.to(device)

        optimizer.zero_grad()

        sbp_outputs, dbp_outputs = model(inputs)

        sbp_outputs = sbp_outputs.squeeze(1)  # reshape from (batch_size, 1) to (batch_size,)
        dbp_outputs = dbp_outputs.squeeze(1)

        loss_sbp = criterion(sbp_outputs, sbp_labels)
        loss_dbp = criterion(dbp_outputs, dbp_labels)

        loss = loss_sbp + loss_dbp

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix(loss=running_loss / (i + 1))

    scheduler.step()
    writer.add_scalar("Loss/train", running_loss / len(dataloader) / batch_size, epoch)

    return running_loss / len(dataloader) / batch_size


# Evaluation function
def evaluate(model, dataloader, device, batch_size):
    model.eval()
    running_loss_sbp = 0.0
    running_loss_dbp = 0.0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.to(device)
            sbp_labels = sbp_labels.to(device)
            dbp_labels = dbp_labels.to(device)

            sbp_outputs, dbp_outputs = model(inputs)

            sbp_outputs = sbp_outputs.squeeze(1)  # reshape from (batch_size, 1) to (batch_size,)
            dbp_outputs = dbp_outputs.squeeze(1)

            loss_sbp = criterion(sbp_outputs, sbp_labels)
            loss_dbp = criterion(dbp_outputs, dbp_labels)

            running_loss_sbp += loss_sbp.item()
            running_loss_dbp += loss_dbp.item()

    eval_loss_sbp = running_loss_sbp / len(dataloader) / batch_size
    eval_loss_dbp = running_loss_dbp / len(dataloader) / batch_size

    return eval_loss_sbp, eval_loss_dbp


# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

data_type = 'MIMIC_full'

# Create the checkpoint directory if it does not exist
if not os.path.exists('weights'):
    os.makedirs('weights')
    # Create a subfolder named after data_type
    os.makedirs(os.path.join('weights', data_type))
else:
    # Make sure the subfolder exists
    if not os.path.exists(os.path.join('weights', data_type)):
        os.makedirs(os.path.join('weights', data_type))


data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)

train_dataloader, val_dataloader = data_loader.get_dataloaders()

best_val_loss_sbp = float('inf')
best_val_loss_dbp = float('inf')

for epoch in range(max_epochs):
    if epoch < warmup_epochs:
        warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = warmup_lr

    train_loss = train(model, train_dataloader, epoch, device, batch_size)
    val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)

    writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
    writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)

    print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

    if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
        best_val_loss_sbp = val_loss_sbp
        best_val_loss_dbp = val_loss_dbp
        torch.save(model.state_dict(), f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')

    torch.save(model.state_dict(), f'weights/{data_type}/last.pth')

writer.close()
@ -0,0 +1,144 @@
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataloader import BPDataLoader
from models.lstm import LSTMModel

# Build the model
model = LSTMModel()

# Load pretrained weights
model.load_state_dict(torch.load(r'weights/MIMIC/best_27_lstm_model_sbp1.4700_dbp0.4493.pth'))

# Training hyperparameters
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 0.0005

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Learning-rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)

# TensorBoard writer
writer = SummaryWriter()


# Training function
def train(model, dataloader, epoch, device, batch_size):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch + 1}/{max_epochs}")
    for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
        inputs = inputs.to(device)
        sbp_labels = sbp_labels.to(device)
        dbp_labels = dbp_labels.to(device)

        optimizer.zero_grad()

        sbp_outputs, dbp_outputs = model(inputs)

        sbp_outputs = sbp_outputs.squeeze(1)  # reshape from (batch_size, 1) to (batch_size,)
        dbp_outputs = dbp_outputs.squeeze(1)

        loss_sbp = criterion(sbp_outputs, sbp_labels)
        loss_dbp = criterion(dbp_outputs, dbp_labels)

        loss = loss_sbp + loss_dbp

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix(loss=running_loss / (i + 1))

    scheduler.step()
    writer.add_scalar("Loss/train", running_loss / len(dataloader) / batch_size, epoch)

    return running_loss / len(dataloader) / batch_size


# Evaluation function
def evaluate(model, dataloader, device, batch_size):
    model.eval()
    running_loss_sbp = 0.0
    running_loss_dbp = 0.0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.to(device)
            sbp_labels = sbp_labels.to(device)
            dbp_labels = dbp_labels.to(device)

            sbp_outputs, dbp_outputs = model(inputs)

            sbp_outputs = sbp_outputs.squeeze(1)  # reshape from (batch_size, 1) to (batch_size,)
            dbp_outputs = dbp_outputs.squeeze(1)

            loss_sbp = criterion(sbp_outputs, sbp_labels)
            loss_dbp = criterion(dbp_outputs, dbp_labels)

            running_loss_sbp += loss_sbp.item()
            running_loss_dbp += loss_dbp.item()

    eval_loss_sbp = running_loss_sbp / len(dataloader) / batch_size
    eval_loss_dbp = running_loss_dbp / len(dataloader) / batch_size

    return eval_loss_sbp, eval_loss_dbp


# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

data_type = 'UKL'

# Create the checkpoint directory if it does not exist
if not os.path.exists('weights'):
    os.makedirs('weights')
    # Create a subfolder named after data_type
    os.makedirs(os.path.join('weights', data_type))
else:
    # Make sure the subfolder exists
    if not os.path.exists(os.path.join('weights', data_type)):
        os.makedirs(os.path.join('weights', data_type))

data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type='UKL')

train_dataloader, val_dataloader = data_loader.get_dataloaders()

best_val_loss_sbp = float('inf')
best_val_loss_dbp = float('inf')

for epoch in range(max_epochs):
    if epoch < warmup_epochs:
        warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = warmup_lr

    train_loss = train(model, train_dataloader, epoch, device, batch_size)
    val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)

    writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
    writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)

    print(f"Epoch {epoch + 1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

    if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
        best_val_loss_sbp = val_loss_sbp
        best_val_loss_dbp = val_loss_dbp
        torch.save(model.state_dict(),
                   f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')

    torch.save(model.state_dict(),
               f'weights/{data_type}/last.pth')

writer.close()
@@ -0,0 +1,144 @@
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataloader import BPDataLoader
from models.lstm import LSTMModel

# Define the model
model = LSTMModel()

# Load pretrained weights
model.load_state_dict(torch.load(r'weights/UKL/best_28_lstm_model_sbp0.3520_dbp0.2052.pth'))

# Training hyperparameters
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 0.0005

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Learning-rate scheduler (cosine annealing after warmup)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=1e-6)

# TensorBoard writer
writer = SummaryWriter()


# Training function
def train(model, dataloader, epoch, device, batch_size):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch + 1}/{max_epochs}")
    for i, (inputs, sbp_labels, dbp_labels) in enumerate(pbar):
        inputs = inputs.to(device)
        sbp_labels = sbp_labels.to(device)
        dbp_labels = dbp_labels.to(device)

        optimizer.zero_grad()

        sbp_outputs, dbp_outputs = model(inputs)

        sbp_outputs = sbp_outputs.squeeze(1)  # reshape the output from (batch_size, 1) to (batch_size,)
        dbp_outputs = dbp_outputs.squeeze(1)

        loss_sbp = criterion(sbp_outputs, sbp_labels)
        loss_dbp = criterion(dbp_outputs, dbp_labels)

        loss = loss_sbp + loss_dbp

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix(loss=running_loss / (i + 1))

    scheduler.step()
    writer.add_scalar("Loss/train", running_loss / len(dataloader) / batch_size, epoch)

    return running_loss / len(dataloader) / batch_size


# Evaluation function
def evaluate(model, dataloader, device, batch_size):
    model.eval()
    running_loss_sbp = 0.0
    running_loss_dbp = 0.0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.to(device)
            sbp_labels = sbp_labels.to(device)
            dbp_labels = dbp_labels.to(device)

            sbp_outputs, dbp_outputs = model(inputs)

            sbp_outputs = sbp_outputs.squeeze(1)  # reshape the output from (batch_size, 1) to (batch_size,)
            dbp_outputs = dbp_outputs.squeeze(1)

            loss_sbp = criterion(sbp_outputs, sbp_labels)
            loss_dbp = criterion(dbp_outputs, dbp_labels)

            running_loss_sbp += loss_sbp.item()
            running_loss_dbp += loss_dbp.item()

    eval_loss_sbp = running_loss_sbp / len(dataloader) / batch_size
    eval_loss_dbp = running_loss_dbp / len(dataloader) / batch_size

    return eval_loss_sbp, eval_loss_dbp


# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

data_type = 'X'

# Create the weight directory and the data_type subfolder if they do not exist
os.makedirs(os.path.join('weights', data_type), exist_ok=True)

data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type='X')

train_dataloader, val_dataloader = data_loader.get_dataloaders()

best_val_loss_sbp = float('inf')
best_val_loss_dbp = float('inf')

for epoch in range(max_epochs):
    if epoch < warmup_epochs:
        # Linear warmup from 1e-6 up to the base learning rate
        warmup_lr = 1e-6 + (epoch + 1) * (5e-4 - 1e-6) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = warmup_lr

    train_loss = train(model, train_dataloader, epoch, device, batch_size)
    val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)

    writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
    writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)

    print(
        f"Epoch {epoch + 1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

    if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
        # Keep the better value for each target so a regression on one branch
        # cannot overwrite its best loss
        best_val_loss_sbp = min(best_val_loss_sbp, val_loss_sbp)
        best_val_loss_dbp = min(best_val_loss_dbp, val_loss_dbp)
        torch.save(model.state_dict(),
                   f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')

    torch.save(model.state_dict(),
               f'weights/{data_type}/last.pth')

writer.close()
@@ -0,0 +1,68 @@
import cv2

from BPApi import BPModel


def main():
    cap = cv2.VideoCapture(0)  # open the webcam

    # Set the capture width and height
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)

    video_fs = cap.get(cv2.CAP_PROP_FPS)
    # print(video_fs)

    # Load the model
    model = BPModel(model_path=r'final/best.pth', fps=video_fs)

    frames = []

    text = ["calculating..."]
    font = cv2.FONT_HERSHEY_SIMPLEX
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Detect faces
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.3, 5)

        if faces is not None and len(faces) > 0:
            # Crop the first detected face region
            x, y, w, h = faces[0]
            face = frame[y:y + h, x:x + w]

            frames.append(face)

            cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 255, 0), 2)
        print(len(frames))

        if len(frames) == 250:

            sbp_outputs, dbp_outputs = model.predict(frames)

            print(sbp_outputs, dbp_outputs)

            text.clear()
            text.append('SBP: {:.2f} mmHg'.format(sbp_outputs))
            text.append('DBP: {:.2f} mmHg'.format(dbp_outputs))

            frames = []
            # Alternatively, keep a sliding window by dropping the oldest frames:
            # frames = frames[50:]

        for i, t in enumerate(text):
            cv2.putText(frame, t, (10, 60 + i * 20), font, 0.6, (0, 255, 0), 2)
        cv2.imshow('Blood Pressure Detection', frame)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
Binary file not shown.
@@ -0,0 +1,34 @@
# Vision-Based Facial Expression Recognition System

This project is an image-based facial expression recognition system. It uses MobileViT trained on a facial expression dataset, then detects faces in webcam video and predicts an expression type for every detected face, covering the 9 expression classes listed in `class_indices.json`.

## Core Files

- `class_indices.json`: Mapping between expression labels and their numeric codes.
- `predict_api.py`: Loading, preprocessing, and inference logic of the image prediction model.
- `video.py`: Main script for video processing and visualization.
- `best.pth`: Trained model weights.

## Usage

1. Make sure the required Python libraries are installed, including `opencv-python`, `torch`, `torchvision`, `Pillow`, and `dlib`.
2. Run the `video.py` script.
3. The script opens the default webcam and starts face detection and expression prediction.
4. Each detected face is outlined with a rectangle, together with the predicted expression type and confidence score.
5. Press `q` to quit.

## Model

The project uses MobileViT as the base model, trained on a facial expression image dataset to predict the expression type of a face image. The model outputs 9 values, one probability per expression class.
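
Mapping the raw outputs to a label is the usual softmax-then-argmax step. A minimal sketch with a stand-in logit tensor (the `ImagePredictor` class in `predict_api.py` wraps this same logic):

```python
import torch

logits = torch.randn(1, 9)               # stand-in for the MobileViT output on one image
probs = torch.softmax(logits, dim=1)[0]  # one probability per expression class
label = int(probs.argmax())              # numeric code, looked up in class_indices.json
print(label, float(probs[label]))
```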

### Dataset

The expression image dataset comes from open web sources and contains 35,887 face images labeled with expression type.

## Pipeline

1. **Face detection**: Detect faces in each video frame with the pretrained face detector from the Dlib library.
2. **Preprocessing**: Resize, crop, and normalize the detected face image to match the model's input requirements.
3. **Inference**: Feed the preprocessed image into the pretrained MobileViT model to obtain probability predictions for each expression type.
4. **Postprocessing**: Take the class with the highest probability as the final prediction.
5. **Visualization**: Draw the face rectangle on the frame and display the predicted expression type and confidence score; a sketch of the whole loop follows this list.
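
A minimal sketch of steps 1-5, assuming the `ImagePredictor` class from `predict_api.py` and the weight/label file names listed above (the actual `video.py` is not shown in this commit, so treat the paths as placeholders):

```python
import cv2
import dlib

from predict_api import ImagePredictor

# Placeholder paths; adjust to the real repository layout.
predictor = ImagePredictor(model_path="best.pth", class_indices_path="class_indices.json")
detector = dlib.get_frontal_face_detector()  # step 1: face detection

cap = cv2.VideoCapture(0)
while True:
    ret, frame = cap.read()
    if not ret:
        break
    rects = detector(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # dlib expects RGB
    for r in rects:
        x1, y1 = max(r.left(), 0), max(r.top(), 0)
        face = frame[y1:r.bottom(), x1:r.right()]
        if face.size == 0:
            continue
        res = predictor.predict(face)["result"]  # steps 2-4 happen inside ImagePredictor
        cv2.rectangle(frame, (x1, y1), (r.right(), r.bottom()), (0, 255, 0), 2)
        # note: cv2.putText cannot render non-ASCII labels, so show the numeric code
        cv2.putText(frame, "{} {:.2f}".format(res["label"], res["score"]), (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    cv2.imshow("Expression", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
```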
Binary file not shown.
@@ -0,0 +1,11 @@
{
    "0": "生气",
    "1": "困惑",
    "2": "厌恶",
    "3": "恐惧",
    "4": "快乐",
    "5": "平静",
    "6": "伤心",
    "7": "害羞",
    "8": "惊喜"
}
Binary file not shown.
@@ -0,0 +1,562 @@
"""
original code from apple:
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
"""

from typing import Optional, Tuple, Union, Dict
import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from transformer import TransformerEncoder
from model_config import get_config


def make_divisible(
        v: Union[float, int],
        divisor: Optional[int] = 8,
        min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvLayer(nn.Module):
    """
    Applies a 2D convolution over an input

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
        out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
        kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
        stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
        groups (Optional[int]): Number of groups in convolution. Default: 1
        bias (Optional[bool]): Use bias. Default: ``False``
        use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
        use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
        Default: ``True``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`

    .. note::
        For depth-wise convolution, `groups=C_{in}=C_{out}`.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Optional[Union[int, Tuple[int, int]]] = 1,
            groups: Optional[int] = 1,
            bias: Optional[bool] = False,
            use_norm: Optional[bool] = True,
            use_act: Optional[bool] = True,
    ) -> None:
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        if isinstance(stride, int):
            stride = (stride, stride)

        assert isinstance(kernel_size, tuple)
        assert isinstance(stride, tuple)

        padding = (
            int((kernel_size[0] - 1) / 2),
            int((kernel_size[1] - 1) / 2),
        )

        block = nn.Sequential()

        conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            padding=padding,
            bias=bias
        )

        block.add_module(name="conv", module=conv_layer)

        if use_norm:
            norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
            block.add_module(name="norm", module=norm_layer)

        if use_act:
            act_layer = nn.SiLU()
            block.add_module(name="act", module=act_layer)

        self.block = block

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class InvertedResidual(nn.Module):
    """
    This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
        out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
        stride (int): Use convolutions with a stride. Default: 1
        expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
        skip_connection (Optional[bool]): Use skip-connection. Default: True

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`

    .. note::
        If `in_channels != out_channels` and `stride > 1`, we set `skip_connection=False`

    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int,
            expand_ratio: Union[int, float],
            skip_connection: Optional[bool] = True,
    ) -> None:
        assert stride in [1, 2]
        hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)

        super().__init__()

        block = nn.Sequential()
        if expand_ratio != 1:
            block.add_module(
                name="exp_1x1",
                module=ConvLayer(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1
                ),
            )

        block.add_module(
            name="conv_3x3",
            module=ConvLayer(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                stride=stride,
                kernel_size=3,
                groups=hidden_dim
            ),
        )

        block.add_module(
            name="red_1x1",
            module=ConvLayer(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                use_act=False,
                use_norm=True,
            ),
        )

        self.block = block
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.exp = expand_ratio
        self.stride = stride
        self.use_res_connect = (
            self.stride == 1 and in_channels == out_channels and skip_connection
        )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        if self.use_res_connect:
            return x + self.block(x)
        else:
            return self.block(x)


class MobileViTBlock(nn.Module):
    """
    This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
        transformer_dim (int): Input dimension to the transformer unit
        ffn_dim (int): Dimension of the FFN block
        n_transformer_blocks (int): Number of transformer blocks. Default: 2
        head_dim (int): Head dimension in the multi-head attention. Default: 32
        attn_dropout (float): Dropout in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
        patch_h (int): Patch height for unfolding operation. Default: 8
        patch_w (int): Patch width for unfolding operation. Default: 8
        conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
    """

    def __init__(
            self,
            in_channels: int,
            transformer_dim: int,
            ffn_dim: int,
            n_transformer_blocks: int = 2,
            head_dim: int = 32,
            attn_dropout: float = 0.0,
            dropout: float = 0.0,
            ffn_dropout: float = 0.0,
            patch_h: int = 8,
            patch_w: int = 8,
            conv_ksize: Optional[int] = 3,
            *args,
            **kwargs
    ) -> None:
        super().__init__()

        conv_3x3_in = ConvLayer(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1
        )
        conv_1x1_in = ConvLayer(
            in_channels=in_channels,
            out_channels=transformer_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False
        )

        conv_1x1_out = ConvLayer(
            in_channels=transformer_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1
        )
        conv_3x3_out = ConvLayer(
            in_channels=2 * in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1
        )

        self.local_rep = nn.Sequential()
        self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
        self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)

        assert transformer_dim % head_dim == 0
        num_heads = transformer_dim // head_dim

        global_rep = [
            TransformerEncoder(
                embed_dim=transformer_dim,
                ffn_latent_dim=ffn_dim,
                num_heads=num_heads,
                attn_dropout=attn_dropout,
                dropout=dropout,
                ffn_dropout=ffn_dropout
            )
            for _ in range(n_transformer_blocks)
        ]
        global_rep.append(nn.LayerNorm(transformer_dim))
        self.global_rep = nn.Sequential(*global_rep)

        self.conv_proj = conv_1x1_out
        self.fusion = conv_3x3_out

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        self.cnn_in_dim = in_channels
        self.cnn_out_dim = transformer_dim
        self.n_heads = num_heads
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.ffn_dropout = ffn_dropout
        self.n_blocks = n_transformer_blocks
        self.conv_ksize = conv_ksize

    def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = patch_w * patch_h
        batch_size, in_channels, orig_h, orig_w = x.shape

        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

        interpolate = False
        if new_w != orig_w or new_h != orig_h:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
        x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
        # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
        x = x.transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        x = x.reshape(batch_size, in_channels, num_patches, patch_area)
        # [B, C, N, P] -> [B, P, N, C]
        x = x.transpose(1, 3)
        # [B, P, N, C] -> [BP, N, C]
        x = x.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }

        return x, info_dict

    def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
        n_dim = x.dim()
        assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
            x.shape
        )
        # [BP, N, C] --> [B, P, N, C]
        x = x.contiguous().view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )

        batch_size, pixels, num_patches, channels = x.size()
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] -> [B, C, N, P]
        x = x.transpose(1, 3)
        # [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
        x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
        # [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
        x = x.transpose(1, 2)
        # [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
        x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
        if info_dict["interpolate"]:
            x = F.interpolate(
                x,
                size=info_dict["orig_size"],
                mode="bilinear",
                align_corners=False,
            )
        return x

    def forward(self, x: Tensor) -> Tensor:
        res = x

        fm = self.local_rep(x)

        # convert feature map to patches
        patches, info_dict = self.unfolding(fm)

        # learn global representations
        for transformer_layer in self.global_rep:
            patches = transformer_layer(patches)

        # convert patches back to a feature map
        fm = self.folding(x=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm


class MobileViT(nn.Module):
    """
    This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
    """
    def __init__(self, model_cfg: Dict, num_classes: int = 1000):
        super().__init__()

        image_channels = 3
        out_channels = 16

        self.conv_1 = ConvLayer(
            in_channels=image_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2
        )

        self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
        self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
        self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
        self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
        self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])

        exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
        self.conv_1x1_exp = ConvLayer(
            in_channels=out_channels,
            out_channels=exp_channels,
            kernel_size=1
        )

        self.classifier = nn.Sequential()
        self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
        self.classifier.add_module(name="flatten", module=nn.Flatten())
        if 0.0 < model_cfg["cls_dropout"] < 1.0:
            self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
        self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))

        # weight init
        self.apply(self.init_parameters)

    def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
        block_type = cfg.get("block_type", "mobilevit")
        if block_type.lower() == "mobilevit":
            return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
        else:
            return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)

    @staticmethod
    def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
        output_channels = cfg.get("out_channels")
        num_blocks = cfg.get("num_blocks", 2)
        expand_ratio = cfg.get("expand_ratio", 4)
        block = []

        for i in range(num_blocks):
            stride = cfg.get("stride", 1) if i == 0 else 1

            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=output_channels,
                stride=stride,
                expand_ratio=expand_ratio
            )
            block.append(layer)
            input_channel = output_channels

        return nn.Sequential(*block), input_channel

    @staticmethod
    def _make_mit_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
        stride = cfg.get("stride", 1)
        block = []

        if stride == 2:
            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=cfg.get("out_channels"),
                stride=stride,
                expand_ratio=cfg.get("mv_expand_ratio", 4)
            )

            block.append(layer)
            input_channel = cfg.get("out_channels")

        transformer_dim = cfg["transformer_channels"]
        ffn_dim = cfg.get("ffn_dim")
        num_heads = cfg.get("num_heads", 4)
        head_dim = transformer_dim // num_heads

        if transformer_dim % head_dim != 0:
            raise ValueError("Transformer input dimension should be divisible by head dimension. "
                             "Got {} and {}.".format(transformer_dim, head_dim))

        block.append(MobileViTBlock(
            in_channels=input_channel,
            transformer_dim=transformer_dim,
            ffn_dim=ffn_dim,
            n_transformer_blocks=cfg.get("transformer_blocks", 1),
            patch_h=cfg.get("patch_h", 2),
            patch_w=cfg.get("patch_w", 2),
            dropout=cfg.get("dropout", 0.1),
            ffn_dropout=cfg.get("ffn_dropout", 0.0),
            attn_dropout=cfg.get("attn_dropout", 0.1),
            head_dim=head_dim,
            conv_ksize=3
        ))

        return nn.Sequential(*block), input_channel

    @staticmethod
    def init_parameters(m):
        if isinstance(m, nn.Conv2d):
            if m.weight is not None:
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.Linear,)):
            if m.weight is not None:
                nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        else:
            pass

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_1(x)
        x = self.layer_1(x)
        x = self.layer_2(x)

        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.layer_5(x)
        x = self.conv_1x1_exp(x)
        x = self.classifier(x)
        return x


def mobile_vit_xx_small(num_classes: int = 1000):
    # pretrain weight link
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
    config = get_config("xx_small")
    m = MobileViT(config, num_classes=num_classes)
    return m


def mobile_vit_x_small(num_classes: int = 1000):
    # pretrain weight link
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
    config = get_config("x_small")
    m = MobileViT(config, num_classes=num_classes)
    return m


def mobile_vit_small(num_classes: int = 1000):
    # pretrain weight link
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
    config = get_config("small")
    m = MobileViT(config, num_classes=num_classes)
    return m
@@ -0,0 +1,176 @@
def get_config(mode: str = "xx_small") -> dict:
    if mode == "xx_small":
        mv2_exp_mult = 2
        config = {
            "layer1": {
                "out_channels": 16,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 24,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 48,
                "transformer_channels": 64,
                "ffn_dim": 128,
                "transformer_blocks": 2,
                "patch_h": 2,  # 8,
                "patch_w": 2,  # 8,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 64,
                "transformer_channels": 80,
                "ffn_dim": 160,
                "transformer_blocks": 4,
                "patch_h": 2,  # 4,
                "patch_w": 2,  # 4,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 80,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "x_small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 48,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 64,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 80,
                "transformer_channels": 120,
                "ffn_dim": 240,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 64,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 128,
                "transformer_channels": 192,
                "ffn_dim": 384,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 160,
                "transformer_channels": 240,
                "ffn_dim": 480,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    else:
        raise NotImplementedError

    for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
        config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})

    return config
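
As a quick sanity check of the config and model definitions above (a sketch, assuming `model.py` and `model_config.py` sit side by side; the 9-class head matches `class_indices.json`):

```python
import torch

from model import mobile_vit_xx_small

net = mobile_vit_xx_small(num_classes=9)
x = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image
print(net(x).shape)              # expected: torch.Size([1, 9])
print(sum(p.numel() for p in net.parameters()) / 1e6, "M parameters")
```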
@@ -0,0 +1,38 @@
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' means a color image, 'L' a grayscale image
        if img.mode != 'RGB':
            # img = img.convert('RGB')
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # For the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
@@ -0,0 +1,61 @@
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import mobile_vit_small as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/best_model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
@@ -0,0 +1,68 @@
import os
import json
import uuid

import cv2
import torch
from PIL import Image
from torchvision import transforms
from model import mobile_vit_small as create_model


class ImagePredictor:
    def __init__(self, model_path, class_indices_path, img_size=224):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.img_size = img_size
        self.data_transform = transforms.Compose([
            transforms.Resize(int(self.img_size * 1.14)),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Load class indices
        with open(class_indices_path, "r", encoding="utf-8") as f:
            self.class_indict = json.load(f)
        # Load model
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        model = create_model(num_classes=9).to(self.device)
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.eval()
        return model

    def predict(self, cv2_image):
        # Convert cv2 image to PIL image
        image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        img = self.data_transform(image)
        img = torch.unsqueeze(img, dim=0)

        # Predict class
        with torch.no_grad():
            output = torch.squeeze(self.model(img.to(self.device))).cpu()
            probabilities = torch.softmax(output, dim=0)
            top_prob, top_catid = torch.topk(probabilities, 1)

        # Top 1 result
        result = {
            "name": self.class_indict[str(top_catid[0].item())],
            "score": top_prob[0].item(),
            "label": top_catid[0].item()
        }

        # Results dictionary
        results = {"result": result, "log_id": str(uuid.uuid1())}

        return results

# Example usage:
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
# result = predictor.predict(cv2.imread("../tulip.jpg"))
# print(result)
@@ -0,0 +1,135 @@
import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import mobile_vit_xx_small as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # Instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # Instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
        # Drop the classifier weights, whose shape depends on the number of classes
        for k in list(weights_dict.keys()):
            if "classifier" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # Freeze everything except the classifier head
            if "classifier" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)

    best_acc = 0.
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "./weights/best_model.pth")

        torch.save(model.state_dict(), "./weights/latest_model.pth")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0002)

    # Dataset root directory
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="/data/flower_photos")

    # Path to pretrained weights; set to an empty string to skip loading
    parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
                        help='initial weights path')
    # Whether to freeze all layers except the classifier head
    # (note: argparse's type=bool treats any non-empty string as True)
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
@@ -0,0 +1,155 @@
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """
    This layer applies a multi-head self- or cross-attention as described in
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (float): Attention dropout. Default: 0.0
        bias (bool): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
        and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input

    """

    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            attn_dropout: float = 0.0,
            bias: bool = True,
            *args,
            **kwargs
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x_q: Tensor) -> Tensor:
        # [N, P, C]
        b_sz, n_patches, in_channels = x_q.shape

        # self-attention
        # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
        qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

        # [N, P, 3, h, c] -> [N, h, 3, P, c]
        qkv = qkv.transpose(1, 3).contiguous()

        # [N, h, 3, P, c] -> [N, h, P, c] x 3
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        query = query * self.scaling

        # [N, h, P, c] -> [N, h, c, P]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
        attn = torch.matmul(query, key)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
        out = torch.matmul(attn, value)

        # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
        out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
        out = self.out_proj(out)

        return out


class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        ffn_latent_dim (int): Inner dimension of the FFN
        num_heads (int) : Number of heads in multi-head attention. Default: 8
        attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers. Default: 0.0

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
        and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
            self,
            embed_dim: int,
            ffn_latent_dim: int,
            num_heads: Optional[int] = 8,
            attn_dropout: Optional[float] = 0.0,
            dropout: Optional[float] = 0.0,
            ffn_dropout: Optional[float] = 0.0,
            *args,
            **kwargs
    ) -> None:

        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True
        )

        self.pre_norm_mha = nn.Sequential(
            nn.LayerNorm(embed_dim),
            attn_unit,
            nn.Dropout(p=dropout)
        )

        self.pre_norm_ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            nn.SiLU(),
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout)
        )
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.std_dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        # multi-head attention with a residual connection
        res = x
        x = self.pre_norm_mha(x)
        x = x + res

        # feed forward network with a residual connection
        x = x + self.pre_norm_ffn(x)
        return x
@@ -0,0 +1,56 @@
import time
import torch

batch_size = 8
in_channels = 32
patch_h = 2
patch_w = 2
num_patch_h = 16
num_patch_w = 16
num_patches = num_patch_h * num_patch_w
patch_area = patch_h * patch_w


def official(x: torch.Tensor):
    # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
    x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
    # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
    x = x.transpose(1, 2)
    # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(batch_size, in_channels, num_patches, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]
    x = x.reshape(batch_size * patch_area, num_patches, -1)

    return x


def my_self(x: torch.Tensor):
    # [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
    x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
    # [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
    x = x.transpose(3, 4)
    # [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(batch_size, in_channels, num_patches, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]
    x = x.reshape(batch_size * patch_area, num_patches, -1)

    return x


if __name__ == '__main__':
    t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
    print(torch.equal(official(t), my_self(t)))

    t1 = time.time()
    for _ in range(1000):
        official(t)
    print(f"official time: {time.time() - t1}")

    t1 = time.time()
    for _ in range(1000):
        my_self(t)
    print(f"self time: {time.time() - t1}")
@ -0,0 +1,179 @@
import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # fix the seed so the split is reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # each sub-directory of root corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort so the ordering is consistent across platforms
    flower_class.sort()
    # build the class-name -> numeric-index mapping
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []  # paths of all validation images
    val_images_label = []  # class index of each validation image
    every_class_num = []  # number of samples in each class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # sort so the ordering is consistent across platforms
        images.sort()
        # numeric index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample the validation images at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled for the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # otherwise it goes to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must be greater than 0."
    assert len(val_images_path) > 0, "number of validation images must be greater than 0."

    plot_image = False
    if plot_image:
        # bar chart of the number of images per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the numeric x ticks with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add a value label on top of each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the normalization
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove x-axis ticks
            plt.yticks([])  # remove y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)  # accumulated number of correctly predicted samples
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)  # accumulated number of correctly predicted samples
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
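A minimal sketch of how these helpers fit together; `MyDataSet`, the dataset path, and the training objects are hypothetical placeholders, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader
from utils import read_split_data, train_one_epoch, evaluate

train_p, train_l, val_p, val_l = read_split_data("./dataset", val_rate=0.2)
# MyDataSet would be a torch Dataset wrapping the path/label lists:
# train_loader = DataLoader(MyDataSet(train_p, train_l), batch_size=8, shuffle=True)
# val_loader = DataLoader(MyDataSet(val_p, val_l), batch_size=8)
# for epoch in range(epochs):
#     train_loss, train_acc = train_one_epoch(model, optimizer, train_loader, device, epoch)
#     val_loss, val_acc = evaluate(model, val_loader, device, epoch)
```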
@ -0,0 +1,84 @@
import cv2
import dlib
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from predict_api import ImagePredictor


def draw_chinese_text(image, text, position, color=(0, 255, 0)):
    # Convert cv2 image to PIL image
    image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Create a blank image with alpha channel, same size as original image
    blank = Image.new('RGBA', image_pil.size, (0, 0, 0, 0))

    # Create a draw object and draw text on the blank image
    draw = ImageDraw.Draw(blank)
    font = ImageFont.truetype("simhei.ttf", 20)
    draw.text(position, text, fill=color, font=font)

    # Composite the original image with the blank image
    image_pil = Image.alpha_composite(image_pil.convert('RGBA'), blank)

    # Convert PIL image back to cv2 image (drop the alpha channel first)
    image = cv2.cvtColor(np.array(image_pil.convert('RGB')), cv2.COLOR_RGB2BGR)

    return image


# Initialize face detector
detector = dlib.get_frontal_face_detector()

# Initialize ImagePredictor
predictor = ImagePredictor(model_path="./best.pth", class_indices_path="./class_indices.json")

# Open the webcam
cap = cv2.VideoCapture(0)

while True:
    # Read a frame from the webcam
    ret, frame = cap.read()
    if not ret:
        break

    # Convert the frame to grayscale
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Detect faces in the frame
    faces = detector(gray)

    for rect in faces:
        # Get the coordinates of the face rectangle
        x = rect.left()
        y = rect.top()
        w = rect.width()
        h = rect.height()

        # Crop the face from the frame
        face = frame[y:y+h, x:x+w]

        # Predict the class of the face crop
        result = predictor.predict(face)

        # Get the class with the highest score
        emotion = result["result"]["name"]

        # Draw the rectangle around the face
        cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)

        # Put the label above the rectangle with cv2
        # cv2.putText(frame, emotion, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Put the label above the rectangle with PIL (supports Chinese text)
        frame = draw_chinese_text(frame, emotion, (x, y))

    # Display the frame
    cv2.imshow("Emotion Recognition", frame)

    # Break the loop if 'q' is pressed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the webcam and destroy all windows
cap.release()
cv2.destroyAllWindows()
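One robustness gap in the loop above: dlib rectangles may extend past the frame, which makes `frame[y:y+h, x:x+w]` an empty array and crashes `predict()`. A sketch of a clamping guard, not part of the original script:

```python
# hypothetical guard, to be placed before the crop:
h_img, w_img = frame.shape[:2]
x1, y1 = max(rect.left(), 0), max(rect.top(), 0)
x2, y2 = min(rect.right(), w_img), min(rect.bottom(), h_img)
if x2 > x1 and y2 > y1:
    face = frame[y1:y2, x1:x2]
```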
@ -0,0 +1,157 @@
import dlib
import numpy as np
import scipy.fftpack as fftpack
from sklearn.decomposition import FastICA
import cv2
from scipy import signal


class HeartRateMonitor:
    def __init__(self, fps, freqs_min, freqs_max):
        self.fps = fps
        self.freqs_min = freqs_min
        self.freqs_max = freqs_max
        self.all_hr_values = []

    def get_channel_signal(self, ROI):
        blue = []
        green = []
        red = []
        for roi in ROI:
            b, g, r = cv2.split(roi)
            b = np.mean(np.sum(b)) / np.std(b)
            g = np.mean(np.sum(g)) / np.std(g)
            r = np.mean(np.sum(r)) / np.std(r)
            blue.append(b)
            green.append(g)
            red.append(r)
        return blue, green, red

    def ICA(self, matrix, n_component, max_iter=200):
        matrix = matrix.T
        ica = FastICA(n_components=n_component, max_iter=max_iter)
        u = ica.fit_transform(matrix)
        return u.T

    def fft_filter(self, signal):
        fft = fftpack.fft(signal, axis=0)
        frequencies = fftpack.fftfreq(signal.shape[0], d=1.0 / self.fps)
        bound_low = (np.abs(frequencies - self.freqs_min)).argmin()
        bound_high = (np.abs(frequencies - self.freqs_max)).argmin()
        fft[:bound_low] = 0
        fft[bound_high:-bound_high] = 0
        fft[-bound_low:] = 0
        return fft, frequencies

    def find_heart_rate(self, fft, freqs):
        fft_maximums = []

        for i in range(fft.shape[0]):
            if self.freqs_min <= freqs[i] <= self.freqs_max:
                fftMap = abs(fft[i])
                fft_maximums.append(fftMap.max())
            else:
                fft_maximums.append(0)

        peaks, properties = signal.find_peaks(fft_maximums)
        max_peak = -1
        max_freq = 0

        for peak in peaks:
            if fft_maximums[peak] > max_freq:
                max_freq = fft_maximums[peak]
                max_peak = peak

        return freqs[max_peak] * 60

    def fourier_transform(self, signal, N, fs):
        result = fftpack.fft(signal, N)
        result = np.abs(result)
        freqs = np.arange(N) / N
        freqs = freqs * fs
        return result[:N // 2], freqs[:N // 2]

    def calculate_hrv(self, hr_values, window_size=5):
        num_values = int(window_size * self.fps)
        start_idx = max(0, len(hr_values) - num_values)
        recent_hr_values = hr_values[start_idx:]
        rr_intervals = np.array(recent_hr_values)

        # SDNN: standard deviation of the intervals
        sdnn = np.std(rr_intervals)

        # RMSSD: root mean square of successive differences
        nn_diffs = np.diff(rr_intervals)
        rmssd = np.sqrt(np.mean(nn_diffs ** 2))

        # CV R-R: coefficient of variation of the R-R intervals
        mean_rr = np.mean(rr_intervals)
        cv_rr = sdnn / mean_rr if mean_rr != 0 else 0

        return sdnn, rmssd, cv_rr

    def process_roi(self, ROI):
        blue, green, red = self.get_channel_signal(ROI)
        matrix = np.array([blue, green, red])
        component = self.ICA(matrix, 3)
        hr_values = []
        for i in range(3):
            fft, freqs = self.fft_filter(component[i])
            heartrate = self.find_heart_rate(fft, freqs)
            hr_values.append(heartrate)
        avg_hr = sum(hr_values) / 3
        self.all_hr_values.append(avg_hr)
        sdnn, rmssd, cv_rr = self.calculate_hrv(self.all_hr_values, window_size=5)
        return avg_hr, sdnn, rmssd, cv_rr


if __name__ == '__main__':

    ROI = []

    freqs_min = 0.8
    freqs_max = 1.8
    heartrate = 0
    sdnn, rmssd, cv_rr = 0, 0, 0
    camera_code = 0
    capture = cv2.VideoCapture(camera_code)
    fps = capture.get(cv2.CAP_PROP_FPS)

    hr_monitor = HeartRateMonitor(fps, freqs_min, freqs_max)

    detector = dlib.get_frontal_face_detector()
    while capture.isOpened():
        ret, frame = capture.read()
        if not ret:
            continue
        dects = detector(frame)
        for face in dects:
            left = face.left()
            right = face.right()
            top = face.top()
            bottom = face.bottom()

            h = bottom - top
            w = right - left
            roi = frame[top + h // 10 * 2:top + h // 10 * 7, left + w // 9 * 2:left + w // 9 * 8]

            cv2.rectangle(frame, (left + w // 9 * 2, top + h // 10 * 2), (left + w // 9 * 8, top + h // 10 * 7),
                          color=(0, 0, 255))
            cv2.rectangle(frame, (left, top), (left + w, top + h), color=(0, 0, 255))
            ROI.append(roi)
            if len(ROI) == 300:
                heartrate, sdnn, rmssd, cv_rr = hr_monitor.process_roi(ROI)
                for i in range(30):
                    ROI.pop(0)
        cv2.putText(frame, '{:.1f}bpm, CV R-R: {:.2f}'.format(heartrate, cv_rr), (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.2,
                    (255, 0, 255), 2)
        cv2.putText(frame, 'SDNN: {:.2f}, RMSSD: {:.2f}'.format(sdnn, rmssd), (50, 80),
                    cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 0, 255), 2)
        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cv2.destroyAllWindows()
    capture.release()
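For reference, the three statistics returned by `calculate_hrv`, written out on a toy sequence; a worked sketch using the same formulas as the code above:

```python
import numpy as np

rr = np.array([70.0, 72.0, 69.0, 71.0])
sdnn = np.std(rr)                            # standard deviation of the values
rmssd = np.sqrt(np.mean(np.diff(rr) ** 2))   # root mean square of successive differences
cv_rr = sdnn / np.mean(rr)                   # coefficient of variation
print(sdnn, rmssd, cv_rr)
```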
@ -0,0 +1,58 @@
# Video-based Heart Rate Monitoring

This project is a video-based heart rate monitoring system. It uses computer vision techniques to extract heart-rate information from video of a person's face. The main features are:

1. Detect the face region
2. Extract the RGB color-channel signals from the face region
3. Use Independent Component Analysis (ICA) to extract the heart-rate-related signal from the RGB signals
4. Analyze the signal in the frequency domain with the FFT to find the corresponding heart-rate value
5. Compute heart rate variability (HRV) metrics such as SDNN, RMSSD and CV R-R

## File Structure

- `HeartRateMonitor.py`: the core logic of the heart-rate monitoring algorithm, together with a demo program.

## Usage

1. Make sure the required Python libraries are installed, including `opencv-python`, `dlib`, `numpy`, `scipy` and `scikit-learn`
2. Run the `HeartRateMonitor.py` script
3. The script opens the default camera and detects the face region
4. The RGB color-channel signals are extracted from the face region, and ICA separates out the heart-rate signal
5. The heart-rate signal is analyzed with the FFT to compute the current heart rate
6. The HRV metrics SDNN, RMSSD and CV R-R are computed at the same time
7. The heart rate and HRV metrics are displayed on the video frame

## Algorithm

### Heart-rate signal extraction

1. Extract the mean and standard deviation of the three RGB channels from the face ROI
2. Feed the three RGB channels into the ICA algorithm as the three rows of a feature matrix
3. ICA decomposes the feature matrix into 3 independent components
4. One of the independent components is selected as the heart-rate signal

### Heart-rate computation

1. Apply the FFT to the heart-rate signal to obtain its frequency-domain representation
2. Filter the FFT result to the configured range of valid heart-rate frequencies
3. The frequency of the maximum in the filtered FFT result is the current heart rate (bpm)

### Heart rate variability metrics

1. A sliding window extracts a segment of the most recent heart-rate values
2. Compute the SDNN (standard deviation), RMSSD (root mean square of successive differences) and CV R-R (coefficient of variation of the R-R intervals) of that segment
3. These three metrics reflect the degree of heart-rate variability

## Parameters

- `freqs_min`: lower bound of the valid heart-rate frequency range (Hz)
- `freqs_max`: upper bound of the valid heart-rate frequency range (Hz)
- `camera_code`: index of the camera to use; 0 is the default camera

## Notes

- The algorithm depends on face detection; an occluded face or a large head pose will reduce the accuracy of the measurement
- Poor lighting conditions may also reduce the measurement accuracy
- Only single-face heart-rate detection is currently supported; multi-person scenes need further work
- The robustness of the algorithm still needs improvement; in unusual conditions it may fail or drift
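A quick worked check of the rate resolution implied by the demo's 300-frame buffer (the ~30 fps camera rate is an assumption, not guaranteed by the code):

```python
# FFT bin spacing is fs / N, so reported heart rates are quantized to that step.
fs, N = 30.0, 300
resolution_hz = fs / N                # 0.1 Hz
resolution_bpm = resolution_hz * 60   # 6.0 bpm steps between distinguishable rates
print(resolution_bpm)
```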
@ -0,0 +1,77 @@
# Video-based Respiration Rate Detection Algorithm

This project implements a video-based respiration rate detection algorithm. It extracts a person's respiration curve from video and computes the respiration rate. The algorithm uses optical flow, a correlation-guided optical flow method, filtering and normalization to improve detection accuracy, and offers several respiration-rate estimators to choose from: FFT, Peak Counting, Crossing Point and Negative Feedback Crossover Point.

## File Structure

- `params.py`: all configurable parameters and their default values.
- `RespirationRateDetector.py`: the core logic of the respiration rate detection algorithm.
- `demo.py`: demo program that reads the video stream from the camera and shows the respiration curve and rate in real time.

## Usage

1. Clone the project.
2. Install the required Python packages: OpenCV, NumPy, SciPy, Matplotlib.
3. Adjust the parameters in `params.py` as needed.
4. Run `demo.py` to start the demo.

The program opens one window showing the video stream captured from the camera and draws the real-time respiration curve in a second window. It also overlays the respiration-rate values computed by the different methods on the video window.

## Core Algorithm

The core steps are:

1. **Optical flow**: track feature points in the video and compute the summed motion amplitude of these points.
2. **Correlation-guided optical flow**: compute the correlation of each feature point with the respiration curve and keep only the respiration-related points, improving detection accuracy.
3. **Filtering**: band-pass filter the raw respiration curve to remove high- and low-frequency noise.
4. **Normalization**: normalize the filtered respiration curve.
5. **Respiration-rate estimation**: compute the respiration rate with the FFT, Peak Counting, Crossing Point and Negative Feedback Crossover Point methods.

## Parameters

`params.py` contains all configurable parameters of the algorithm and their defaults. The main ones are:

- `--video-path`: path of the input video file. Default: './1.mp4'.

- `--FSS`: enable the Feature Point Selection Strategy. Default: True.
- `--CGOF`: enable the Correlation-Guided Optical Flow method. Default: True.
- `--filter`: filter the respiration curve. Default: True.
- `--Normalization`: normalize the respiration curve. Default: True.
- `--RR_Evaluation`: compute the respiration rate. Default: True.

The remaining parameters control the details of the optical flow, the feature point selection strategy, the filtering and the respiration-rate computation.

- `--OFP-maxCorners`: maximum number of feature points detected by the optical flow. Default: 100.
- `--OFP-qualityLevel`: quality level for feature point detection in the optical flow. Default: 0.1.
- `--OFP-minDistance`: minimum distance between feature points in the optical flow. Default: 7.
- `--OFP-mask`: mask used by the optical flow to specify the region of interest. Default: None.
- `--OFP-QualityLevelRV`: step by which the quality level is reduced when not enough feature points are found. Default: 0.05.
- `--OFP-winSize`: window size of the pyramidal Lucas-Kanade optical flow estimator. Default: (15, 15).
- `--OFP-maxLevel`: number of pyramid levels in the optical flow. Default: 2.

- `--FSS-switch`: enable the feature point selection strategy.
- `--FSS-maxCorners`: maximum number of feature points detected by the selection strategy. Default: 100.
- `--FSS-qualityLevel`: quality level for feature point detection in the selection strategy. Default: 0.1.
- `--FSS-minDistance`: minimum distance between feature points in the selection strategy. Default: 7.
- `--FSS-mask`: mask used by the feature point selection strategy. Default: None.
- `--FSS-QualityLevelRV`: step by which the quality level is reduced when not enough feature points are found. Default: 0.05.
- `--FSS-FPN`: number of feature points the selection strategy keeps. Default: 5.

- `--CGOF-switch`: enable the correlation-guided optical flow method.

- `--Filter-switch`: filter the respiration curve.
- `--Filter-type`: filter type, one of 'lowpass', 'highpass', 'bandpass' and 'bandstop'. Default: 'bandpass'.
- `--Filter-order`: filter order. Default: 3.
- `--Filter-LowPass`: low cutoff of the band-pass filter (breaths/minute). Default: 2.
- `--Filter-HighPass`: high cutoff of the band-pass filter (breaths/minute). Default: 40.

- `--Normalization-switch`: normalize the respiration curve.

- `--RR-switch`: compute the respiration rate.

- `--RR-Algorithm-PC-Height`: peak height threshold used by the Peak Counting algorithm. Default: None.
- `--RR-Algorithm-PC-Threshold`: peak threshold used by the Peak Counting algorithm. Default: None.
- `--RR-Algorithm-PC-MaxRR`: maximum respiration rate (breaths/minute) for the Peak Counting algorithm. Default: 45.
- `--RR-Algorithm-CP-shfit_distance`: shift distance used by the Crossing Point algorithm. Default: 15.
- `--RR-Algorithm-NFCP-shfit_distance`: shift distance used by the Negative Feedback Crossover Point algorithm. Default: 15.
- `--RR-Algorithm-NFCP-qualityLevel`: quality level used by the Negative Feedback Crossover Point algorithm. Default: 0.6.
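A worked check of the band-pass design used in `ImproveOpticalFlow` (fs = 30 fps is an assumption; `scipy.signal.butter` expects cutoffs normalized by the Nyquist frequency fs/2, which is where the factor `2 * f / fs` comes from):

```python
from scipy import signal

fs = 30.0
low_hz = 2 / 60    # --Filter-LowPass, converted from breaths/minute to Hz
high_hz = 40 / 60  # --Filter-HighPass, converted from breaths/minute to Hz
b, a = signal.butter(3, [2 * low_hz / fs, 2 * high_hz / fs], 'bandpass')
```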
@ -0,0 +1,233 @@
import cv2
import numpy as np
from scipy import signal
from scipy.fftpack import fft
from scipy.signal import find_peaks


class RespirationRateDetector:
    def __init__(self, args):
        self.args = args

    def FeaturePointSelectionStrategy(self, Image, FPN=5, QualityLevel=0.3):
        Image_gray = Image
        feature_params = dict(maxCorners=self.args.FSS_maxCorners,
                              qualityLevel=QualityLevel,
                              minDistance=self.args.FSS_minDistance)

        p0 = cv2.goodFeaturesToTrack(Image_gray, mask=self.args.FSS_mask, **feature_params)

        """ Robust checking """
        while p0 is None:
            QualityLevel = QualityLevel - self.args.FSS_QualityLevelRV
            feature_params = dict(maxCorners=self.args.FSS_maxCorners,
                                  qualityLevel=QualityLevel,
                                  minDistance=self.args.FSS_minDistance)
            p0 = cv2.goodFeaturesToTrack(Image_gray, mask=None, **feature_params)

        if len(p0) < FPN:
            FPN = len(p0)

        h = Image_gray.shape[0] / 2
        w = Image_gray.shape[1] / 2

        p1 = p0.copy()
        p1[:, :, 0] -= w
        p1[:, :, 1] -= h
        p1_1 = np.multiply(p1, p1)
        p1_2 = np.sum(p1_1, 2)
        p1_3 = np.sqrt(p1_2)
        p1_4 = p1_3[:, 0]
        p1_5 = np.argsort(p1_4)

        FPMap = np.zeros((FPN, 1, 2), dtype=np.float32)
        for i in range(FPN):
            FPMap[i, :, :] = p0[p1_5[i], :, :]

        return FPMap

    def CorrelationGuidedOpticalFlowMethod(self, FeatureMtx_Amp, RespCurve):
        CGAmp_Mtx = FeatureMtx_Amp.T
        CGAmpAugmented_Mtx = np.zeros((CGAmp_Mtx.shape[0] + 1, CGAmp_Mtx.shape[1]))
        CGAmpAugmented_Mtx[0, :] = RespCurve
        CGAmpAugmented_Mtx[1:, :] = CGAmp_Mtx

        Correlation_Mtx = np.corrcoef(CGAmpAugmented_Mtx)
        CM_mean = np.mean(abs(Correlation_Mtx[0, 1:]))
        Quality_num = (abs(Correlation_Mtx[0, 1:]) >= CM_mean).sum()
        QualityFeaturePoint_arg = (abs(Correlation_Mtx[0, 1:]) >= CM_mean).argsort()[0 - Quality_num:]

        CGOF_Mtx = np.zeros((FeatureMtx_Amp.shape[0], Quality_num))

        for i in range(Quality_num):
            CGOF_Mtx[:, i] = FeatureMtx_Amp[:, QualityFeaturePoint_arg[i]]

        CGOF_Mtx_RespCurve = np.sum(CGOF_Mtx, 1) / Quality_num

        return CGOF_Mtx_RespCurve

    def ImproveOpticalFlow(self, frames, fs):
        feature_params = dict(maxCorners=self.args.OFP_maxCorners,
                              qualityLevel=self.args.OFP_qualityLevel,
                              minDistance=self.args.OFP_minDistance)

        old_frame = frames[0]
        old_gray = cv2.cvtColor(old_frame, cv2.COLOR_BGR2GRAY)
        p0 = cv2.goodFeaturesToTrack(old_gray, mask=self.args.OFP_mask, **feature_params)

        """ Robust checking """
        while p0 is None:
            self.args.OFP_qualityLevel = self.args.OFP_qualityLevel - self.args.OFP_QualityLevelRV
            feature_params = dict(maxCorners=self.args.OFP_maxCorners,
                                  qualityLevel=self.args.OFP_qualityLevel,
                                  minDistance=self.args.OFP_minDistance)
            p0 = cv2.goodFeaturesToTrack(old_gray, mask=None, **feature_params)

        """ FeaturePoint Selection Strategy """
        if self.args.FSS:
            p0 = self.FeaturePointSelectionStrategy(Image=old_gray, FPN=self.args.FSS_FPN,
                                                    QualityLevel=self.args.FSS_qualityLevel)
        else:
            p0 = cv2.goodFeaturesToTrack(old_gray, mask=None, **feature_params)

        lk_params = dict(winSize=self.args.OFP_winSize, maxLevel=self.args.OFP_maxLevel)
        total_frame = len(frames)

        FeatureMtx = np.zeros((total_frame, p0.shape[0], 2))
        FeatureMtx[0, :, 0] = p0[:, 0, 0].T
        FeatureMtx[0, :, 1] = p0[:, 0, 1].T
        frame_num = 1

        while frame_num < total_frame:
            frame_num += 1
            frame = frames[frame_num - 1]
            frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            pl, st, err = cv2.calcOpticalFlowPyrLK(old_gray, frame_gray, p0, None, **lk_params)

            old_gray = frame_gray.copy()
            p0 = pl.reshape(-1, 1, 2)
            FeatureMtx[frame_num - 1, :, 0] = p0[:, 0, 0].T
            FeatureMtx[frame_num - 1, :, 1] = p0[:, 0, 1].T

        FeatureMtx_Amp = np.sqrt(FeatureMtx[:, :, 0] ** 2 + FeatureMtx[:, :, 1] ** 2)
        RespCurve = np.sum(FeatureMtx_Amp, 1) / p0.shape[0]

        """ Correlation-Guided Optical Flow Method """
        if self.args.CGOF:
            RespCurve = self.CorrelationGuidedOpticalFlowMethod(FeatureMtx_Amp, RespCurve)

        """ Filter """
        if self.args.filter:
            original_signal = RespCurve
            filter_order = self.args.Filter_order
            LowPass = self.args.Filter_LowPass / 60
            HighPass = self.args.Filter_HighPass / 60
            b, a = signal.butter(filter_order, [2 * LowPass / fs, 2 * HighPass / fs], self.args.Filter_type)
            filtedResp = signal.filtfilt(b, a, original_signal)
        else:
            filtedResp = RespCurve

        """ Normalization """
        if self.args.Normalization:
            Resp_max = max(filtedResp)
            Resp_min = min(filtedResp)

            Resp_norm = (filtedResp - Resp_min) / (Resp_max - Resp_min) - 0.5
        else:
            Resp_norm = filtedResp

        return 1 - Resp_norm

    def FFT(self, data, fs):
        fft_y = fft(data)
        maxFrequency = fs
        f = np.linspace(0, maxFrequency, len(data))
        abs_y = np.abs(fft_y)
        normalization_y = abs_y / len(data)
        normalization_half_y = normalization_y[range(int(len(data) / 2))]
        sorted_indices = np.argsort(normalization_half_y)
        RR = f[sorted_indices[-2]] * 60
        return RR

    def PeakCounting(self, data, fs, Height=0.1, Threshold=0.2, MaxRR=30):
        Distance = 60 / MaxRR * fs
        peaks, _ = find_peaks(data, height=Height, threshold=Threshold, distance=Distance)
        RR = len(peaks) / (len(data) / fs) * 60
        return RR

    def CrossingPoint(self, data, fs):
        shfit_distance = int(fs / 2)
        data_shift = np.zeros(data.shape) - 1
        data_shift[shfit_distance:] = data[:-shfit_distance]
        cross_curve = data - data_shift

        zero_number = 0
        zero_index = []
        for i in range(len(cross_curve) - 1):
            if cross_curve[i] == 0:
                zero_number += 1
                zero_index.append(i)
            else:
                if cross_curve[i] * cross_curve[i + 1] < 0:
                    zero_number += 1
                    zero_index.append(i)

        cw = zero_number
        N = len(data)
        RR1 = ((cw / 2) / (N / fs)) * 60

        return RR1

    def NegativeFeedbackCrossoverPointMethod(self, data, fs, QualityLevel=0.2):
        shfit_distance = int(fs / 2)
        data_shift = np.zeros(data.shape) - 1
        data_shift[shfit_distance:] = data[:-shfit_distance]
        cross_curve = data - data_shift

        zero_number = 0
        zero_index = []
        for i in range(len(cross_curve) - 1):
            if cross_curve[i] == 0:
                zero_number += 1
                zero_index.append(i)
            else:
                if cross_curve[i] * cross_curve[i + 1] < 0:
                    zero_number += 1
                    zero_index.append(i)

        cw = zero_number
        N = len(data)
        RR1 = ((cw / 2) / (N / fs)) * 60

        if len(zero_index) <= 1:
            RR2 = RR1
        else:
            time_span = 60 / RR1 / 2 * fs * QualityLevel
            zero_span = []
            for i in range(len(zero_index) - 1):
                zero_span.append(zero_index[i + 1] - zero_index[i])

            while min(zero_span) < time_span:
                doubt_point = np.argmin(zero_span)
                zero_index.pop(doubt_point)
                zero_index.pop(doubt_point)
                if len(zero_index) <= 1:
                    break
                zero_span = []
                for i in range(len(zero_index) - 1):
                    zero_span.append(zero_index[i + 1] - zero_index[i])

            zero_number = len(zero_index)
            cw = zero_number
            RR2 = ((cw / 2) / (N / fs)) * 60

        return RR2

    def detect_respiration_rate(self, frames, fs):
        resp_curve = self.ImproveOpticalFlow(frames, fs)
        RR_FFT = self.FFT(resp_curve, fs)
        RR_PC = self.PeakCounting(resp_curve, fs)
        RR_CP = self.CrossingPoint(resp_curve, fs)
        RR_NFCP = self.NegativeFeedbackCrossoverPointMethod(resp_curve, fs)
        return resp_curve, RR_FFT, RR_PC, RR_CP, RR_NFCP
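A minimal offline usage sketch for the detector (the video path comes from `params.py`; the frame-collection loop is written out here and is not part of the detector itself):

```python
import cv2
from params import args
from RespirationRateDetector import RespirationRateDetector

cap = cv2.VideoCapture(args.video_path)
fs = cap.get(cv2.CAP_PROP_FPS)
frames = []
ok, frame = cap.read()
while ok:
    frames.append(frame)
    ok, frame = cap.read()
cap.release()

detector = RespirationRateDetector(args)
curve, rr_fft, rr_pc, rr_cp, rr_nfcp = detector.detect_respiration_rate(frames, fs)
print(rr_fft, rr_pc, rr_cp, rr_nfcp)
```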
@ -0,0 +1,104 @@
import queue

import cv2
import numpy as np

from RespirationRateDetector import RespirationRateDetector
from params import args
import matplotlib.pyplot as plt


def main():
    cap = cv2.VideoCapture(0)  # use the webcam
    video_fs = cap.get(cv2.CAP_PROP_FPS)

    detector = RespirationRateDetector(args)

    frames = []

    text = ["calculating..."]
    font = cv2.FONT_HERSHEY_SIMPLEX
    # face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    resps = queue.Queue()

    fig, ax = plt.subplots()
    line, = ax.plot([], [])
    plt.ion()
    plt.show()

    xdata = []
    ydata = []

    last = 0

    while True:
        ret, frame = cap.read()
        if not ret:  # stop if the frame could not be read
            break

        frames.append(frame)

        if len(frames) == 300:

            Resp, RR_FFT, RR_PC, RR_CP, RR_NFCP = detector.detect_respiration_rate(frames, video_fs)

            Resp[0] = last

            for res in Resp:
                resps.put(res)

            last = Resp[-1]

            text.clear()
            text.append('RR-FFT: {:.2f} bpm'.format(RR_FFT))
            text.append('RR-PC: {:.2f} bpm'.format(RR_PC))
            text.append('RR-CP: {:.2f} bpm'.format(RR_CP))
            text.append('RR-NFCP: {:.2f} bpm'.format(RR_NFCP))
            frames = []
            # alternatively, keep a sliding window by dropping only the oldest frames:
            # frames = frames[50:]

        if not resps.empty():
            resp = resps.get()
            # update the curve data
            ydata.append(resp)
        else:
            ydata.append(0)

        if len(xdata) == 0:
            xdata.append(1)
        else:
            xdata.append(xdata[-1] + 1)

        if len(xdata) > 600:
            xdata.pop(0)
            ydata.pop(0)

        # build the time axis
        t = np.linspace(xdata[0] / video_fs, xdata[-1] / video_fs, len(ydata))

        line.set_data(t, ydata)  # use the time series as the x axis

        # update the axis limits
        ax.set_xlim(t[0], t[-1])

        ax.set_ylim(min(0, min(ydata)) - 0.5 * abs(min(ydata)), 1.5 * max(ydata))
        # refresh the plot
        plt.draw()
        plt.pause(0.01)

        for i, t in enumerate(text):
            cv2.putText(frame, t, (10, 60 + i * 20), font, 0.6, (0, 255, 0), 2)
        cv2.imshow('Respiration Rate Detection', frame)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
    plt.close()


if __name__ == '__main__':
    main()
@ -0,0 +1,55 @@
import argparse

parser = argparse.ArgumentParser('Lightweight Video-based Respiration Rate Detection Algorithm script', add_help=False)
parser.add_argument('--video-path', default='./1.mp4', help='Video input path')

# NOTE: argparse's type=bool does not parse strings; any non-empty value is truthy
parser.add_argument('--FSS', default=True, type=bool, help='enable the feature point selection strategy')
parser.add_argument('--CGOF', default=True, type=bool, help='enable the correlation-guided optical flow method')
parser.add_argument('--filter', default=True, type=bool, help='filter the respiration curve')
parser.add_argument('--Normalization', default=True, type=bool, help='normalize the respiration curve')
parser.add_argument('--RR_Evaluation', default=True, type=bool, help='compute the respiration rate')

# Optical flow parameters
parser.add_argument('--OFP-maxCorners', default=100, type=int, help='maximum number of feature points to detect')
parser.add_argument('--OFP-qualityLevel', default=0.1, type=float, help='feature point detection quality level')
parser.add_argument('--OFP-minDistance', default=7, type=int, help='minimum distance between feature points')
parser.add_argument('--OFP-mask', default=None, help='mask specifying the region of interest')
parser.add_argument('--OFP-QualityLevelRV', default=0.05, type=float, help='QualityLevel reduction value')
parser.add_argument('--OFP-winSize', default=(15, 15), help='Lucas-Kanade window size')
parser.add_argument('--OFP-maxLevel', default=2, type=int, help='number of pyramid levels')

# FeaturePoint Selection Strategy parameters
parser.add_argument('--FSS-switch', action='store_true', dest='FSS_switch')
parser.add_argument('--FSS-maxCorners', default=100, type=int, help='maximum number of feature points to detect')
parser.add_argument('--FSS-qualityLevel', default=0.1, type=float, help='feature point detection quality level')
parser.add_argument('--FSS-minDistance', default=7, type=int, help='minimum distance between feature points')
parser.add_argument('--FSS-mask', default=None, help='mask specifying the region of interest')
parser.add_argument('--FSS-QualityLevelRV', default=0.05, type=float, help='QualityLevel reduction value')
parser.add_argument('--FSS-FPN', default=5, type=int,
                    help='The number of feature points for the feature point selection strategy')

# Correlation-Guided Optical Flow Method parameters
parser.add_argument('--CGOF-switch', action='store_true', dest='CGOF_switch')

# Filter parameters
parser.add_argument('--Filter-switch', action='store_true', dest='Filter_switch')
parser.add_argument('--Filter-type', default='bandpass', help="filter type: 'lowpass', 'highpass', 'bandpass' or 'bandstop'")
parser.add_argument('--Filter-order', default=3, type=int, help='filter order')
parser.add_argument('--Filter-LowPass', default=2, type=int, help='low cutoff of the band-pass filter (breaths/minute)')
parser.add_argument('--Filter-HighPass', default=40, type=int, help='high cutoff of the band-pass filter (breaths/minute)')

# Normalization parameters
parser.add_argument('--Normalization-switch', action='store_true', dest='Normalization_switch')

# RR Evaluation parameters
parser.add_argument('--RR-switch', action='store_true', dest='RR_switch')

# RR Algorithm parameters
parser.add_argument('--RR-Algorithm-PC-Height', default=None, help='peak height threshold for Peak Counting')
parser.add_argument('--RR-Algorithm-PC-Threshold', default=None, help='peak threshold for Peak Counting')
parser.add_argument('--RR-Algorithm-PC-MaxRR', default=45, type=int, help='maximum respiration rate (breaths/minute)')
parser.add_argument('--RR-Algorithm-CP-shfit_distance', default=15, type=int, help='shift distance for Crossing Point')
parser.add_argument('--RR-Algorithm-NFCP-shfit_distance', default=15, type=int, help='shift distance for NFCP')
parser.add_argument('--RR-Algorithm-NFCP-qualityLevel', default=0.6, type=float, help='quality level for NFCP')

args = parser.parse_args()
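One caveat worth noting about the boolean flags above: argparse's `type=bool` applies `bool()` to the raw string, and any non-empty string is truthy, so these flags cannot actually be disabled from the command line. A small demonstration:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--FSS', default=True, type=bool)
# bool('False') is True, so the flag stays enabled:
print(p.parse_args(['--FSS', 'False']).FSS)  # prints: True
```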
@ -0,0 +1,34 @@
# Vision-based Skin Disease Detection System

This project is an image-based skin disease detection system. It trains MobileViT on a skin-image dataset, then detects faces in the camera video stream and predicts the skin disease type for each detected face, with 24 supported classes.

## Core Files

- `class_indices.json`: mapping between the skin-disease labels and their numeric codes.
- `predict_api.py`: loading, preprocessing and inference logic of the image prediction model.
- `video.py`: main script for video processing and visualization.
- `best300_model_0.7302241690286009.pth`: trained model weights.

## Usage

1. Make sure the required Python libraries are installed, including `opencv-python`, `torch`, `torchvision`, `Pillow` and `dlib`.
2. Run the `video.py` script.
3. The script opens the default camera and starts face detection and skin-disease prediction.
4. Detected faces are marked with rectangles, together with the predicted skin-disease type and confidence score.
5. Press `q` to quit.

## Model

The project uses MobileViT as the base model, trained on a skin-disease image dataset to predict the skin type of face images. The model outputs 24 values, one probability per skin-disease class.

### Dataset

The skin-disease image dataset used in this project comes from open web data and contains 20,000 human skin images annotated with the disease type.

## Pipeline

1. **Face detection**: the pretrained Dlib face detector finds faces in each video frame.
2. **Preprocessing**: detected face crops are resized, cropped and normalized to match the model's input requirements.
3. **Inference**: the preprocessed image is fed into the pretrained MobileViT model, which outputs probabilities for the different skin-disease classes.
4. **Post-processing**: the class with the highest probability is taken as the final prediction.
5. **Visualization**: rectangles are drawn around the faces, and the predicted skin-disease type and confidence score are displayed on the frame.
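A minimal sketch of pipeline steps 2-4 (the preprocessing values shown are the standard ImageNet ones and `face.jpg` is a hypothetical input; the actual `predict_api.py` implementation may differ):

```python
import json
import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = preprocess(Image.open("face.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    probs = torch.softmax(model(img), dim=1)  # model: a loaded 24-class MobileViT
conf, idx = probs.max(dim=1)
name = json.load(open("class_indices.json"))[str(idx.item())]
```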
Binary file not shown.
@ -0,0 +1,28 @@
{
    "0": "痤疮或酒渣鼻",
    "1": "光化性角化病基底细胞癌或其他恶性病变",
    "2": "过敏性皮炎",
    "3": "大疱性疾病",
    "4": "蜂窝织炎、脓疱病或其他细菌感染",
    "5": "湿疹",
    "6": "皮疹或药疹",
    "7": "脱发或其他头发疾病",
    "8": "健康",
    "9": "疱疹、HPV或其他性病",
    "10": "轻度疾病和色素沉着障碍",
    "11": "狼疮或其他结缔组织疾病",
    "12": "黑色素瘤皮肤癌痣或痣",
    "13": "指甲真菌或其他指甲疾病",
    "14": "毒藤或其他接触性皮炎",
    "15": "牛皮癣、扁平苔藓或相关疾病",
    "16": "疥疮、莱姆病或其他感染和叮咬",
    "17": "脂溢性角化病或其他良性肿瘤",
    "18": "全身性疾病",
    "19": "癣念珠菌病或其他真菌感染",
    "20": "荨麻疹",
    "21": "血管肿瘤",
    "22": "血管炎",
    "23": "疣、软疣或其他病毒感染"
}
@ -0,0 +1,562 @@
|
||||||
|
"""
|
||||||
|
original code from apple:
|
||||||
|
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union, Dict
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from transformer import TransformerEncoder
|
||||||
|
from model_config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(
|
||||||
|
v: Union[float, int],
|
||||||
|
divisor: Optional[int] = 8,
|
||||||
|
min_value: Optional[Union[float, int]] = None,
|
||||||
|
) -> Union[float, int]:
|
||||||
|
"""
|
||||||
|
This function is taken from the original tf repo.
|
||||||
|
It ensures that all layers have a channel number that is divisible by 8
|
||||||
|
It can be seen here:
|
||||||
|
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||||
|
:param v:
|
||||||
|
:param divisor:
|
||||||
|
:param min_value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if min_value is None:
|
||||||
|
min_value = divisor
|
||||||
|
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||||
|
# Make sure that round down does not go down by more than 10%.
|
||||||
|
if new_v < 0.9 * v:
|
||||||
|
new_v += divisor
|
||||||
|
return new_v
|
||||||
|
|
||||||
|
|
||||||
|
class ConvLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
Applies a 2D convolution over an input
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||||
|
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||||
|
kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
|
||||||
|
stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
|
||||||
|
groups (Optional[int]): Number of groups in convolution. Default: 1
|
||||||
|
bias (Optional[bool]): Use bias. Default: ``False``
|
||||||
|
use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
|
||||||
|
use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
|
||||||
|
Default: ``True``
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||||
|
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
For depth-wise convolution, `groups=C_{in}=C_{out}`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, Tuple[int, int]],
|
||||||
|
stride: Optional[Union[int, Tuple[int, int]]] = 1,
|
||||||
|
groups: Optional[int] = 1,
|
||||||
|
bias: Optional[bool] = False,
|
||||||
|
use_norm: Optional[bool] = True,
|
||||||
|
use_act: Optional[bool] = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
kernel_size = (kernel_size, kernel_size)
|
||||||
|
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = (stride, stride)
|
||||||
|
|
||||||
|
assert isinstance(kernel_size, Tuple)
|
||||||
|
assert isinstance(stride, Tuple)
|
||||||
|
|
||||||
|
padding = (
|
||||||
|
int((kernel_size[0] - 1) / 2),
|
||||||
|
int((kernel_size[1] - 1) / 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
block = nn.Sequential()
|
||||||
|
|
||||||
|
conv_layer = nn.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
groups=groups,
|
||||||
|
padding=padding,
|
||||||
|
bias=bias
|
||||||
|
)
|
||||||
|
|
||||||
|
block.add_module(name="conv", module=conv_layer)
|
||||||
|
|
||||||
|
if use_norm:
|
||||||
|
norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
|
||||||
|
block.add_module(name="norm", module=norm_layer)
|
||||||
|
|
||||||
|
if use_act:
|
||||||
|
act_layer = nn.SiLU()
|
||||||
|
block.add_module(name="act", module=act_layer)
|
||||||
|
|
||||||
|
self.block = block
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidual(nn.Module):
|
||||||
|
"""
|
||||||
|
This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||||
|
out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`
|
||||||
|
stride (int): Use convolutions with a stride. Default: 1
|
||||||
|
expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
|
||||||
|
skip_connection (Optional[bool]): Use skip-connection. Default: True
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
|
||||||
|
- Output: :math:`(N, C_{out}, H_{out}, W_{out})`
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
stride: int,
|
||||||
|
expand_ratio: Union[int, float],
|
||||||
|
skip_connection: Optional[bool] = True,
|
||||||
|
) -> None:
|
||||||
|
assert stride in [1, 2]
|
||||||
|
hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
block = nn.Sequential()
|
||||||
|
if expand_ratio != 1:
|
||||||
|
block.add_module(
|
||||||
|
name="exp_1x1",
|
||||||
|
module=ConvLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=hidden_dim,
|
||||||
|
kernel_size=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
block.add_module(
|
||||||
|
name="conv_3x3",
|
||||||
|
module=ConvLayer(
|
||||||
|
in_channels=hidden_dim,
|
||||||
|
out_channels=hidden_dim,
|
||||||
|
stride=stride,
|
||||||
|
kernel_size=3,
|
||||||
|
groups=hidden_dim
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
block.add_module(
|
||||||
|
name="red_1x1",
|
||||||
|
module=ConvLayer(
|
||||||
|
in_channels=hidden_dim,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
use_act=False,
|
||||||
|
use_norm=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.block = block
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.exp = expand_ratio
|
||||||
|
self.stride = stride
|
||||||
|
self.use_res_connect = (
|
||||||
|
self.stride == 1 and in_channels == out_channels and skip_connection
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||||
|
if self.use_res_connect:
|
||||||
|
return x + self.block(x)
|
||||||
|
else:
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MobileViTBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
opts: command line arguments
|
||||||
|
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
|
||||||
|
transformer_dim (int): Input dimension to the transformer unit
|
||||||
|
ffn_dim (int): Dimension of the FFN block
|
||||||
|
n_transformer_blocks (int): Number of transformer blocks. Default: 2
|
||||||
|
head_dim (int): Head dimension in the multi-head attention. Default: 32
|
||||||
|
attn_dropout (float): Dropout in multi-head attention. Default: 0.0
|
||||||
|
dropout (float): Dropout rate. Default: 0.0
|
||||||
|
ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
|
||||||
|
patch_h (int): Patch height for unfolding operation. Default: 8
|
||||||
|
patch_w (int): Patch width for unfolding operation. Default: 8
|
||||||
|
transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
|
||||||
|
conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
|
||||||
|
no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
transformer_dim: int,
|
||||||
|
ffn_dim: int,
|
||||||
|
n_transformer_blocks: int = 2,
|
||||||
|
head_dim: int = 32,
|
||||||
|
attn_dropout: float = 0.0,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
ffn_dropout: float = 0.0,
|
||||||
|
patch_h: int = 8,
|
||||||
|
patch_w: int = 8,
|
||||||
|
conv_ksize: Optional[int] = 3,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
conv_3x3_in = ConvLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
kernel_size=conv_ksize,
|
||||||
|
stride=1
|
||||||
|
)
|
||||||
|
conv_1x1_in = ConvLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=transformer_dim,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
use_norm=False,
|
||||||
|
use_act=False
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_1x1_out = ConvLayer(
|
||||||
|
in_channels=transformer_dim,
|
||||||
|
out_channels=in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1
|
||||||
|
)
|
||||||
|
conv_3x3_out = ConvLayer(
|
||||||
|
in_channels=2 * in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
kernel_size=conv_ksize,
|
||||||
|
stride=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.local_rep = nn.Sequential()
|
||||||
|
self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
|
||||||
|
self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
|
||||||
|
|
||||||
|
assert transformer_dim % head_dim == 0
|
||||||
|
num_heads = transformer_dim // head_dim
|
||||||
|
|
||||||
|
global_rep = [
|
||||||
|
TransformerEncoder(
|
||||||
|
embed_dim=transformer_dim,
|
||||||
|
ffn_latent_dim=ffn_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
dropout=dropout,
|
||||||
|
ffn_dropout=ffn_dropout
|
||||||
|
)
|
||||||
|
for _ in range(n_transformer_blocks)
|
||||||
|
]
|
||||||
|
global_rep.append(nn.LayerNorm(transformer_dim))
|
||||||
|
self.global_rep = nn.Sequential(*global_rep)
|
||||||
|
|
||||||
|
self.conv_proj = conv_1x1_out
|
||||||
|
self.fusion = conv_3x3_out
|
||||||
|
|
||||||
|
self.patch_h = patch_h
|
||||||
|
self.patch_w = patch_w
|
||||||
|
self.patch_area = self.patch_w * self.patch_h
|
||||||
|
|
||||||
|
self.cnn_in_dim = in_channels
|
||||||
|
self.cnn_out_dim = transformer_dim
|
||||||
|
self.n_heads = num_heads
|
||||||
|
self.ffn_dim = ffn_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attn_dropout = attn_dropout
|
||||||
|
self.ffn_dropout = ffn_dropout
|
||||||
|
self.n_blocks = n_transformer_blocks
|
||||||
|
self.conv_ksize = conv_ksize
|
||||||
|
|
||||||
|
def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
|
||||||
|
patch_w, patch_h = self.patch_w, self.patch_h
|
||||||
|
patch_area = patch_w * patch_h
|
||||||
|
batch_size, in_channels, orig_h, orig_w = x.shape
|
||||||
|
|
||||||
|
new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
|
||||||
|
new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
|
||||||
|
|
||||||
|
interpolate = False
|
||||||
|
if new_w != orig_w or new_h != orig_h:
|
||||||
|
# Note: Padding can be done, but then it needs to be handled in attention function.
|
||||||
|
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
|
||||||
|
interpolate = True
|
||||||
|
|
||||||
|
# number of patches along width and height
|
||||||
|
num_patch_w = new_w // patch_w # n_w
|
||||||
|
num_patch_h = new_h // patch_h # n_h
|
||||||
|
num_patches = num_patch_h * num_patch_w # N
|
||||||
|
|
||||||
|
# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
|
||||||
|
x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
|
||||||
|
# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
|
||||||
|
x = x.reshape(batch_size, in_channels, num_patches, patch_area)
|
||||||
|
# [B, C, N, P] -> [B, P, N, C]
|
||||||
|
x = x.transpose(1, 3)
|
||||||
|
# [B, P, N, C] -> [BP, N, C]
|
||||||
|
x = x.reshape(batch_size * patch_area, num_patches, -1)
|
||||||
|
|
||||||
|
info_dict = {
|
||||||
|
"orig_size": (orig_h, orig_w),
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"interpolate": interpolate,
|
||||||
|
"total_patches": num_patches,
|
||||||
|
"num_patches_w": num_patch_w,
|
||||||
|
"num_patches_h": num_patch_h,
|
||||||
|
}
|
||||||
|
|
||||||
|
return x, info_dict
|
||||||
|
|
||||||
|
def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
|
||||||
|
n_dim = x.dim()
|
||||||
|
assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
|
||||||
|
x.shape
|
||||||
|
)
|
||||||
|
# [BP, N, C] --> [B, P, N, C]
|
||||||
|
x = x.contiguous().view(
|
||||||
|
info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size, pixels, num_patches, channels = x.size()
|
||||||
|
num_patch_h = info_dict["num_patches_h"]
|
||||||
|
num_patch_w = info_dict["num_patches_w"]
|
||||||
|
|
||||||
|
# [B, P, N, C] -> [B, C, N, P]
|
||||||
|
x = x.transpose(1, 3)
|
||||||
|
# [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
|
||||||
|
x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
|
||||||
|
# [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
# [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
|
||||||
|
x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
|
||||||
|
if info_dict["interpolate"]:
|
||||||
|
x = F.interpolate(
|
||||||
|
x,
|
||||||
|
size=info_dict["orig_size"],
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
res = x
|
||||||
|
|
||||||
|
fm = self.local_rep(x)
|
||||||
|
|
||||||
|
# convert feature map to patches
|
||||||
|
patches, info_dict = self.unfolding(fm)
|
||||||
|
|
||||||
|
# learn global representations
|
||||||
|
for transformer_layer in self.global_rep:
|
||||||
|
patches = transformer_layer(patches)
|
        # fold the patches [BP, N, C] back into a feature map [B, C, H, W]
        fm = self.folding(x=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm


class MobileViT(nn.Module):
    """
    This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
    """
    def __init__(self, model_cfg: Dict, num_classes: int = 1000):
        super().__init__()

        image_channels = 3
        out_channels = 16

        self.conv_1 = ConvLayer(
            in_channels=image_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2
        )

        self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
        self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
        self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
        self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
        self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])

        exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
        self.conv_1x1_exp = ConvLayer(
            in_channels=out_channels,
            out_channels=exp_channels,
            kernel_size=1
        )

        self.classifier = nn.Sequential()
        self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
        self.classifier.add_module(name="flatten", module=nn.Flatten())
        if 0.0 < model_cfg["cls_dropout"] < 1.0:
            self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
        self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))

        # weight init
        self.apply(self.init_parameters)

    def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
        block_type = cfg.get("block_type", "mobilevit")
        if block_type.lower() == "mobilevit":
            return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
        else:
            return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)

    @staticmethod
    def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
        output_channels = cfg.get("out_channels")
        num_blocks = cfg.get("num_blocks", 2)
        expand_ratio = cfg.get("expand_ratio", 4)
        block = []

        for i in range(num_blocks):
            stride = cfg.get("stride", 1) if i == 0 else 1

            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=output_channels,
                stride=stride,
                expand_ratio=expand_ratio
            )
            block.append(layer)
            input_channel = output_channels

        return nn.Sequential(*block), input_channel

    @staticmethod
    def _make_mit_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
        stride = cfg.get("stride", 1)
        block = []

        if stride == 2:
            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=cfg.get("out_channels"),
                stride=stride,
                expand_ratio=cfg.get("mv_expand_ratio", 4)
            )

            block.append(layer)
            input_channel = cfg.get("out_channels")

        transformer_dim = cfg["transformer_channels"]
        ffn_dim = cfg.get("ffn_dim")
        num_heads = cfg.get("num_heads", 4)
        head_dim = transformer_dim // num_heads

        if transformer_dim % head_dim != 0:
            raise ValueError("Transformer input dimension should be divisible by head dimension. "
                             "Got {} and {}.".format(transformer_dim, head_dim))

        block.append(MobileViTBlock(
            in_channels=input_channel,
            transformer_dim=transformer_dim,
            ffn_dim=ffn_dim,
            n_transformer_blocks=cfg.get("transformer_blocks", 1),
            patch_h=cfg.get("patch_h", 2),
            patch_w=cfg.get("patch_w", 2),
            dropout=cfg.get("dropout", 0.1),
            ffn_dropout=cfg.get("ffn_dropout", 0.0),
            attn_dropout=cfg.get("attn_dropout", 0.1),
            head_dim=head_dim,
            conv_ksize=3
        ))

        return nn.Sequential(*block), input_channel

    @staticmethod
    def init_parameters(m):
        if isinstance(m, nn.Conv2d):
            if m.weight is not None:
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.Linear,)):
            if m.weight is not None:
                nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        else:
            pass

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_1(x)
        x = self.layer_1(x)
        x = self.layer_2(x)

        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.layer_5(x)
        x = self.conv_1x1_exp(x)
        x = self.classifier(x)
        return x


def mobile_vit_xx_small(num_classes: int = 1000):
    # pretrained weight link:
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
    config = get_config("xx_small")
    m = MobileViT(config, num_classes=num_classes)
    return m


def mobile_vit_x_small(num_classes: int = 1000):
    # pretrained weight link:
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
    config = get_config("x_small")
    m = MobileViT(config, num_classes=num_classes)
    return m


def mobile_vit_small(num_classes: int = 1000):
    # pretrained weight link:
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
    config = get_config("small")
    m = MobileViT(config, num_classes=num_classes)
    return m
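A quick way to sanity-check the factory functions above is to push a dummy batch through one of them. This is a minimal sketch, assuming the file above is saved as `model.py` next to `model_config.py` (which is how the other scripts in this commit import it); the class count of 24 is only an example:

```python
# Minimal smoke test for the factory functions above (hypothetical usage).
import torch

from model import mobile_vit_small  # assumes this file is saved as model.py

net = mobile_vit_small(num_classes=24)  # 24 classes is just an example
net.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # [N, C, H, W], the expected input size
    logits = net(dummy)

print(logits.shape)  # torch.Size([1, 24]) -- one logit per class
```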
@ -0,0 +1,176 @@
def get_config(mode: str = "xx_small") -> dict:
    if mode == "xx_small":
        mv2_exp_mult = 2
        config = {
            "layer1": {
                "out_channels": 16,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 24,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 48,
                "transformer_channels": 64,
                "ffn_dim": 128,
                "transformer_blocks": 2,
                "patch_h": 2,  # 8,
                "patch_w": 2,  # 8,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 64,
                "transformer_channels": 80,
                "ffn_dim": 160,
                "transformer_blocks": 4,
                "patch_h": 2,  # 4,
                "patch_w": 2,  # 4,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 80,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "x_small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 48,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 64,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 80,
                "transformer_channels": 120,
                "ffn_dim": 240,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 64,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 128,
                "transformer_channels": 192,
                "ffn_dim": 384,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 160,
                "transformer_channels": 240,
                "ffn_dim": 480,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    else:
        raise NotImplementedError

    for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
        config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})

    return config
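The returned config is a plain dict, so individual entries can be overridden before the model is built. A minimal sketch, assuming this file is saved as `model_config.py` next to `model.py`; the dropout value below is illustrative, not a tuned setting:

```python
# Hypothetical override of one config field before model construction.
from model_config import get_config
from model import MobileViT

cfg = get_config("small")
cfg["cls_dropout"] = 0.2  # illustrative value, not a tuned setting

net = MobileViT(cfg, num_classes=24)  # 24 classes is just an example
```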
@ -0,0 +1,37 @@
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' means a color image, 'L' means a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # The official default_collate implementation can be found at
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
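A minimal sketch of wiring `MyDataSet` into a `DataLoader` with its custom `collate_fn`; the image paths and labels below are placeholders:

```python
# Hypothetical usage of MyDataSet; image paths and labels are placeholders.
from torch.utils.data import DataLoader
from torchvision import transforms

from my_dataset import MyDataSet  # assumes the file above is my_dataset.py

transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

dataset = MyDataSet(images_path=["img_0.jpg", "img_1.jpg"],  # placeholder paths
                    images_class=[0, 1],
                    transform=transform)

loader = DataLoader(dataset, batch_size=2, shuffle=True,
                    collate_fn=MyDataSet.collate_fn)

for images, labels in loader:
    print(images.shape, labels)  # e.g. torch.Size([2, 3, 224, 224]) tensor([1, 0])
    break
```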
@ -0,0 +1,64 @@
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import mobile_vit_small as create_model
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Configure matplotlib to render Chinese characters
plt.rcParams['font.sans-serif'] = ['SimHei']

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = r"E:\Download\data\train\Acne and Rosacea Photos\acne-closed-comedo-8.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=24).to(device)
    # load model weights
    model_weight_path = "./best300_model_0.7302241690286009.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
@ -0,0 +1,90 @@
import os
import json
import uuid
import torch
from PIL import Image
from torchvision import transforms
from model import mobile_vit_small as create_model

class ImagePredictor:
    def __init__(self, model_path, class_indices_path, img_size=224):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.img_size = img_size
        self.data_transform = transforms.Compose([
            transforms.Resize(int(self.img_size * 1.14)),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Load class indices
        with open(class_indices_path, "r", encoding="utf-8") as f:
            self.class_indict = json.load(f)
        # Load model
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        model = create_model(num_classes=24).to(self.device)
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.eval()
        return model

    def predict_img(self, image_path):
        # Load and transform image
        assert os.path.exists(image_path), f"file: '{image_path}' does not exist."
        img = Image.open(image_path).convert('RGB')
        img = self.data_transform(img)
        img = torch.unsqueeze(img, dim=0)

        # Predict class
        with torch.no_grad():
            output = torch.squeeze(self.model(img.to(self.device))).cpu()
            probabilities = torch.softmax(output, dim=0)
            top_prob, top_catid = torch.topk(probabilities, 5)

        # Top-5 results
        top5 = []
        for i in range(top_prob.size(0)):
            top5.append({
                "name": self.class_indict[str(top_catid[i].item())],
                "score": top_prob[i].item(),
                "label": top_catid[i].item()
            })

        # Results dictionary
        results = {"result": top5, "log_id": str(uuid.uuid1())}

        return results

    def predict(self, np_image):
        # Convert numpy image to PIL image
        img = Image.fromarray(np_image).convert('RGB')

        # Transform image
        img = self.data_transform(img)
        img = torch.unsqueeze(img, dim=0)

        # Predict class
        with torch.no_grad():
            output = torch.squeeze(self.model(img.to(self.device))).cpu()
            probabilities = torch.softmax(output, dim=0)
            top_prob, top_catid = torch.topk(probabilities, 1)

        # Top-1 result, kept in the same list format as predict_img
        top1 = []
        for i in range(top_prob.size(0)):
            top1.append({
                "name": self.class_indict[str(top_catid[i].item())],
                "score": top_prob[i].item(),
                "label": top_catid[i].item()
            })

        # Results dictionary
        results = {"result": top1, "log_id": str(uuid.uuid1())}

        return results

# Example usage:
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
# result = predictor.predict_img("../tulip.jpg")
# print(result)
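The `predict` method takes a raw NumPy frame (for example, a face crop from OpenCV), which is the entry point `video.py` relies on below. A hedged sketch; the weight and index paths are placeholders:

```python
# Hypothetical usage of the NumPy entry point; file paths are placeholders.
import numpy as np

from predict_api import ImagePredictor

predictor = ImagePredictor(model_path="./weights/best_model.pth",      # placeholder path
                           class_indices_path="./class_indices.json")

# An RGB uint8 array stands in for a face crop taken from a video frame.
face = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
result = predictor.predict(face)

print(result["result"][0]["name"], result["result"][0]["score"])
```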
@ -0,0 +1,135 @@
import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import mobile_vit_xx_small as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if not os.path.exists("./weights"):
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # Instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # Instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
        # Drop the classification-head weights (the number of classes differs)
        for k in list(weights_dict.keys()):
            if "classifier" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # Freeze all weights except the classification head
            if "classifier" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)

    best_acc = 0.
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "./weights/best_model.pth")

        torch.save(model.state_dict(), "./weights/latest_model.pth")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0002)

    # Root directory of the dataset
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="/data/flower_photos")

    # Path to pretrained weights; set it to an empty string to train from scratch
    parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
                        help='initial weights path')
    # Whether to freeze the weights
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
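For experiments in a notebook or another script, `main` can also be driven without the command line by building the argument namespace directly. A sketch with placeholder values, assuming the script above is saved as `train.py` (note that argparse maps `--batch-size` to the attribute `batch_size`, and so on):

```python
# Hypothetical programmatic invocation of main(); all values are placeholders.
import argparse

from train import main

opt = argparse.Namespace(num_classes=24,
                         epochs=10,
                         batch_size=8,
                         lr=0.0002,
                         data_path="/data/flower_photos",
                         weights="./mobilevit_xxs.pt",
                         freeze_layers=False,
                         device="cuda:0")
main(opt)
```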
@ -0,0 +1,155 @@
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """
    This layer applies a multi-head self- or cross-attention as described in the
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (float): Attention dropout. Default: 0.0
        bias (bool): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input

    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: float = 0.0,
        bias: bool = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x_q: Tensor) -> Tensor:
        # [N, P, C]
        b_sz, n_patches, in_channels = x_q.shape

        # self-attention
        # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
        qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

        # [N, P, 3, h, c] -> [N, h, 3, P, c]
        qkv = qkv.transpose(1, 3).contiguous()

        # [N, h, 3, P, c] -> [N, h, P, c] x 3
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        query = query * self.scaling

        # [N, h, P, c] -> [N, h, c, P]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
        attn = torch.matmul(query, key)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
        out = torch.matmul(attn, value)

        # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
        out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
        out = self.out_proj(out)

        return out


class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        ffn_latent_dim (int): Inner dimension of the FFN
        num_heads (int) : Number of heads in multi-head attention. Default: 8
        attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers. Default: 0.0

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        *args,
        **kwargs
    ) -> None:

        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True
        )

        self.pre_norm_mha = nn.Sequential(
            nn.LayerNorm(embed_dim),
            attn_unit,
            nn.Dropout(p=dropout)
        )

        self.pre_norm_ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            nn.SiLU(),
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout)
        )
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.std_dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        # multi-head attention with a residual connection
        res = x
        x = self.pre_norm_mha(x)
        x = x + res

        # feed forward network with a residual connection
        x = x + self.pre_norm_ffn(x)
        return x
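Both blocks preserve the `[N, P, C]` shape they receive, which the MobileViT block depends on when folding patches back into a feature map. A quick shape check, as a minimal sketch assuming this file is saved as `transformer.py`:

```python
# Minimal shape check for the blocks above (hypothetical standalone usage).
import torch

from transformer import TransformerEncoder

encoder = TransformerEncoder(embed_dim=96, ffn_latent_dim=192, num_heads=4)

x = torch.randn(4, 256, 96)   # [N, P, C]: batch of 4, 256 patches, embedding dim 96
out = encoder(x)

assert out.shape == x.shape   # the encoder is shape-preserving
```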
@ -0,0 +1,56 @@
import time
import torch

batch_size = 8
in_channels = 32
patch_h = 2
patch_w = 2
num_patch_h = 16
num_patch_w = 16
num_patches = num_patch_h * num_patch_w
patch_area = patch_h * patch_w


def official(x: torch.Tensor):
    # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
    x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
    # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
    x = x.transpose(1, 2)
    # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(batch_size, in_channels, num_patches, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]
    x = x.reshape(batch_size * patch_area, num_patches, -1)

    return x


def my_self(x: torch.Tensor):
    # [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
    x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
    # [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
    x = x.transpose(3, 4)
    # [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(batch_size, in_channels, num_patches, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]
    x = x.reshape(batch_size * patch_area, num_patches, -1)

    return x


if __name__ == '__main__':
    t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
    print(torch.equal(official(t), my_self(t)))

    t1 = time.time()
    for _ in range(1000):
        official(t)
    print(f"official time: {time.time() - t1}")

    t1 = time.time()
    for _ in range(1000):
        my_self(t)
    print(f"self time: {time.time() - t1}")
@ -0,0 +1,179 @@
import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # fix the seed so the random split is reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # Walk the directories; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # Sort so the order is consistent across platforms
    flower_class.sort()
    # Build the mapping from class name to numeric index
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # all image paths in the training set
    train_images_label = []  # index labels of the training images
    val_images_path = []  # all image paths in the validation set
    val_images_label = []  # index labels of the validation images
    every_class_num = []  # total number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # Iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # Collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # Sort so the order is consistent across platforms
        images.sort()
        # Numeric index of this class
        image_class = class_indices[cla]
        # Record the number of samples in this class
        every_class_num.append(len(images))
        # Randomly sample validation images at the given rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled paths go to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # everything else goes to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must be greater than 0."
    assert len(val_images_path) > 0, "number of validation images must be greater than 0."

    plot_image = False
    if plot_image:
        # Bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # Replace the x ticks 0, 1, 2, 3, 4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # Add a count label on top of each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove the x-axis ticks
            plt.yticks([])  # remove the y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)  # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)  # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
@ -0,0 +1,51 @@
import cv2
import dlib
import numpy as np
from PIL import Image

from predict_api import ImagePredictor

# Initialize camera and face detector
cap = cv2.VideoCapture(0)
detector = dlib.get_frontal_face_detector()

# Initialize ImagePredictor
predictor = ImagePredictor(model_path="best300_model_0.7302241690286009.pth", class_indices_path="./class_indices.json")

while True:
    # Capture frame-by-frame
    ret, frame = cap.read()
    if not ret:
        break

    # Convert the image from BGR color (which OpenCV uses) to RGB color
    rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Perform face detection
    faces = detector(rgb_image)

    # Loop through each face in this frame
    for rect in faces:
        # Get the bounding box coordinates
        x1, y1, x2, y2 = rect.left(), rect.top(), rect.right(), rect.bottom()

        # Crop the face from the frame
        face_image = rgb_image[y1:y2, x1:x2]

        # Use ImagePredictor to predict the class of this face
        result = predictor.predict(face_image)

        # Draw a rectangle around the face
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Display the class name and score
        cv2.putText(frame, f"{result['result'][0]['name']}: {round(result['result'][0]['score'], 4)}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    # Display the resulting frame
    cv2.imshow('Video', frame)

    # Exit loop if 'q' is pressed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# When everything is done, release the capture
cap.release()
cv2.destroyAllWindows()
@ -0,0 +1,34 @@
# Vision-Based Skin Type Detection System

This project is an image-based skin type detection system. It trains MobileViT on a skin image dataset, then detects faces in video captured from a webcam and predicts the skin type (dry, normal, or oily) of each detected face.

## Core Files

- `class_indices.json`: Maps the skin type labels to their numeric codes.
- `predict_api.py`: Contains the loading, preprocessing, and inference logic of the image prediction model.
- `video.py`: Main script for video processing and visualization.
- `best_model_'0.8998410174880763'.pth`: Model weights trained on the skin image dataset.

## Usage

1. Make sure the required Python libraries are installed, including `opencv-python`, `torch`, `torchvision`, `Pillow`, and `dlib`.
2. Run the `video.py` script.
3. The script opens the default webcam and starts face detection and skin type prediction.
4. Each detected face is outlined with a rectangle and annotated with the predicted skin type and confidence score.
5. Press `q` to exit the program.

## Model

This project uses MobileViT as the base model, trained on a skin image dataset to predict the skin type of face images. The model outputs 3 values, corresponding to the probabilities of the dry, normal, and oily skin types.

### Dataset

The skin image dataset comes from the Kaggle platform and contains 3152 face images labeled with a skin type (dry, normal, or oily).

## Pipeline

1. **Face detection**: A pretrained face detector from the Dlib library locates faces in each video frame.
2. **Preprocessing**: Each detected face image is resized, cropped, and normalized to match the model's input requirements.
3. **Inference**: The preprocessed image is fed into the pretrained MobileViT model, which returns probability predictions for the different skin types.
4. **Postprocessing**: The class with the highest probability is taken as the final prediction.
5. **Visualization**: A rectangle is drawn around each face in the video frame, annotated with the predicted skin type and confidence score, as sketched below.
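A minimal sketch of one pass through this pipeline, condensed from `video.py`; the input image path is a placeholder, and the weight file is assumed to sit next to the script:

```python
# Condensed single-image version of the video.py loop; paths are placeholders.
import cv2
import dlib

from predict_api import ImagePredictor

detector = dlib.get_frontal_face_detector()
predictor = ImagePredictor(model_path="best300_model_0.7302241690286009.pth",
                           class_indices_path="./class_indices.json")

frame = cv2.imread("face.jpg")                   # placeholder input image
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
for rect in detector(rgb):                       # step 1: face detection
    face = rgb[rect.top():rect.bottom(), rect.left():rect.right()]
    result = predictor.predict(face)             # steps 2-4: preprocess, infer, take top-1
    print(result["result"][0]["name"], result["result"][0]["score"])  # step 5 draws this
```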
Binary file not shown.
@ -0,0 +1,7 @@
{
    "0": "Dry",
    "1": "Normal",
    "2": "Oily"
}
@ -0,0 +1,562 @@
"""
original code from apple:
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
"""

from typing import Optional, Tuple, Union, Dict
import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from transformer import TransformerEncoder
from model_config import get_config


def make_divisible(
    v: Union[float, int],
    divisor: Optional[int] = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvLayer(nn.Module):
    """
    Applies a 2D convolution over an input

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
        out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
        kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
        stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
        groups (Optional[int]): Number of groups in convolution. Default: 1
        bias (Optional[bool]): Use bias. Default: ``False``
        use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
        use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
            Default: ``True``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`

    .. note::
        For depth-wise convolution, `groups=C_{in}=C_{out}`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = 1,
        groups: Optional[int] = 1,
        bias: Optional[bool] = False,
        use_norm: Optional[bool] = True,
        use_act: Optional[bool] = True,
    ) -> None:
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        if isinstance(stride, int):
            stride = (stride, stride)

        assert isinstance(kernel_size, Tuple)
        assert isinstance(stride, Tuple)

        padding = (
            int((kernel_size[0] - 1) / 2),
            int((kernel_size[1] - 1) / 2),
        )

        block = nn.Sequential()

        conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            padding=padding,
            bias=bias
        )

        block.add_module(name="conv", module=conv_layer)

        if use_norm:
            norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
            block.add_module(name="norm", module=norm_layer)

        if use_act:
            act_layer = nn.SiLU()
            block.add_module(name="act", module=act_layer)

        self.block = block

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class InvertedResidual(nn.Module):
    """
    This class implements the inverted residual block, as described in the `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
        out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
        stride (int): Use convolutions with a stride. Default: 1
        expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
        skip_connection (Optional[bool]): Use skip-connection. Default: True

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`

    .. note::
        If `in_channels != out_channels` and `stride > 1`, we set `skip_connection=False`

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: Union[int, float],
        skip_connection: Optional[bool] = True,
    ) -> None:
        assert stride in [1, 2]
        hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)

        super().__init__()

        block = nn.Sequential()
        if expand_ratio != 1:
            block.add_module(
                name="exp_1x1",
                module=ConvLayer(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1
                ),
            )

        block.add_module(
            name="conv_3x3",
            module=ConvLayer(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                stride=stride,
                kernel_size=3,
                groups=hidden_dim
            ),
        )

        block.add_module(
            name="red_1x1",
            module=ConvLayer(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                use_act=False,
                use_norm=True,
            ),
        )

        self.block = block
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.exp = expand_ratio
        self.stride = stride
        self.use_res_connect = (
            self.stride == 1 and in_channels == out_channels and skip_connection
        )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        if self.use_res_connect:
            return x + self.block(x)
        else:
            return self.block(x)


class MobileViTBlock(nn.Module):
    """
    This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
        transformer_dim (int): Input dimension to the transformer unit
        ffn_dim (int): Dimension of the FFN block
        n_transformer_blocks (int): Number of transformer blocks. Default: 2
        head_dim (int): Head dimension in the multi-head attention. Default: 32
        attn_dropout (float): Dropout in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
        patch_h (int): Patch height for unfolding operation. Default: 8
        patch_w (int): Patch width for unfolding operation. Default: 8
        conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
    """

    def __init__(
        self,
        in_channels: int,
        transformer_dim: int,
        ffn_dim: int,
        n_transformer_blocks: int = 2,
        head_dim: int = 32,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        patch_h: int = 8,
        patch_w: int = 8,
        conv_ksize: Optional[int] = 3,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        conv_3x3_in = ConvLayer(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1
        )
        conv_1x1_in = ConvLayer(
            in_channels=in_channels,
            out_channels=transformer_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False
        )

        conv_1x1_out = ConvLayer(
            in_channels=transformer_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1
        )
        conv_3x3_out = ConvLayer(
            in_channels=2 * in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1
        )

        self.local_rep = nn.Sequential()
        self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
        self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)

        assert transformer_dim % head_dim == 0
        num_heads = transformer_dim // head_dim

        global_rep = [
            TransformerEncoder(
                embed_dim=transformer_dim,
                ffn_latent_dim=ffn_dim,
                num_heads=num_heads,
                attn_dropout=attn_dropout,
                dropout=dropout,
                ffn_dropout=ffn_dropout
            )
            for _ in range(n_transformer_blocks)
        ]
        global_rep.append(nn.LayerNorm(transformer_dim))
        self.global_rep = nn.Sequential(*global_rep)

        self.conv_proj = conv_1x1_out
        self.fusion = conv_3x3_out

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        self.cnn_in_dim = in_channels
        self.cnn_out_dim = transformer_dim
        self.n_heads = num_heads
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.ffn_dropout = ffn_dropout
        self.n_blocks = n_transformer_blocks
        self.conv_ksize = conv_ksize

    def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = patch_w * patch_h
        batch_size, in_channels, orig_h, orig_w = x.shape

        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

        interpolate = False
        if new_w != orig_w or new_h != orig_h:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
        x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
        # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
        x = x.transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        x = x.reshape(batch_size, in_channels, num_patches, patch_area)
        # [B, C, N, P] -> [B, P, N, C]
        x = x.transpose(1, 3)
        # [B, P, N, C] -> [BP, N, C]
        x = x.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }

        return x, info_dict

    def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
        n_dim = x.dim()
        assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
            x.shape
        )
        # [BP, N, C] --> [B, P, N, C]
        x = x.contiguous().view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )

        batch_size, pixels, num_patches, channels = x.size()
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] -> [B, C, N, P]
        x = x.transpose(1, 3)
        # [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
        x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
        # [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
        x = x.transpose(1, 2)
        # [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
        x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
        if info_dict["interpolate"]:
            x = F.interpolate(
                x,
                size=info_dict["orig_size"],
                mode="bilinear",
                align_corners=False,
            )
        return x

    def forward(self, x: Tensor) -> Tensor:
        res = x

        fm = self.local_rep(x)

        # convert feature map to patches
        patches, info_dict = self.unfolding(fm)

        # learn global representations
        for transformer_layer in self.global_rep:
            patches = transformer_layer(patches)

        # fold the patches [BP, N, C] back into a feature map [B, C, H, W]
        fm = self.folding(x=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm


class MobileViT(nn.Module):
    """
    This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
    """
    def __init__(self, model_cfg: Dict, num_classes: int = 1000):
        super().__init__()

        image_channels = 3
        out_channels = 16

        self.conv_1 = ConvLayer(
            in_channels=image_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2
        )

        self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
        self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
        self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
        self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
        self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])

        exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
        self.conv_1x1_exp = ConvLayer(
            in_channels=out_channels,
            out_channels=exp_channels,
            kernel_size=1
        )

        self.classifier = nn.Sequential()
        self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
        self.classifier.add_module(name="flatten", module=nn.Flatten())
        if 0.0 < model_cfg["cls_dropout"] < 1.0:
            self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
        self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))

        # weight init
        self.apply(self.init_parameters)

    def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
        block_type = cfg.get("block_type", "mobilevit")
        if block_type.lower() == "mobilevit":
            return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
        else:
            return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)

    @staticmethod
    def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
        output_channels = cfg.get("out_channels")
        num_blocks = cfg.get("num_blocks", 2)
        expand_ratio = cfg.get("expand_ratio", 4)
        block = []

        for i in range(num_blocks):
            stride = cfg.get("stride", 1) if i == 0 else 1

            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=output_channels,
                stride=stride,
                expand_ratio=expand_ratio
            )
            block.append(layer)
            input_channel = output_channels

        return nn.Sequential(*block), input_channel

    @staticmethod
    def _make_mit_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
        stride = cfg.get("stride", 1)
        block = []

        if stride == 2:
            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=cfg.get("out_channels"),
                stride=stride,
                expand_ratio=cfg.get("mv_expand_ratio", 4)
            )

            block.append(layer)
            input_channel = cfg.get("out_channels")

        transformer_dim = cfg["transformer_channels"]
        ffn_dim = cfg.get("ffn_dim")
        num_heads = cfg.get("num_heads", 4)
        head_dim = transformer_dim // num_heads

        if transformer_dim % head_dim != 0:
            raise ValueError("Transformer input dimension should be divisible by head dimension. "
                             "Got {} and {}.".format(transformer_dim, head_dim))

        block.append(MobileViTBlock(
            in_channels=input_channel,
            transformer_dim=transformer_dim,
            ffn_dim=ffn_dim,
            n_transformer_blocks=cfg.get("transformer_blocks", 1),
            patch_h=cfg.get("patch_h", 2),
            patch_w=cfg.get("patch_w", 2),
            dropout=cfg.get("dropout", 0.1),
            ffn_dropout=cfg.get("ffn_dropout", 0.0),
            attn_dropout=cfg.get("attn_dropout", 0.1),
            head_dim=head_dim,
            conv_ksize=3
        ))

        return nn.Sequential(*block), input_channel

    @staticmethod
    def init_parameters(m):
        if isinstance(m, nn.Conv2d):
            if m.weight is not None:
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.Linear,)):
            if m.weight is not None:
                nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
x = self.conv_1(x)
|
||||||
|
x = self.layer_1(x)
|
||||||
|
x = self.layer_2(x)
|
||||||
|
|
||||||
|
x = self.layer_3(x)
|
||||||
|
x = self.layer_4(x)
|
||||||
|
x = self.layer_5(x)
|
||||||
|
x = self.conv_1x1_exp(x)
|
||||||
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def mobile_vit_xx_small(num_classes: int = 1000):
|
||||||
|
# pretrain weight link
|
||||||
|
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
|
||||||
|
config = get_config("xx_small")
|
||||||
|
m = MobileViT(config, num_classes=num_classes)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def mobile_vit_x_small(num_classes: int = 1000):
|
||||||
|
# pretrain weight link
|
||||||
|
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
|
||||||
|
config = get_config("x_small")
|
||||||
|
m = MobileViT(config, num_classes=num_classes)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def mobile_vit_small(num_classes: int = 1000):
|
||||||
|
# pretrain weight link
|
||||||
|
# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
|
||||||
|
config = get_config("small")
|
||||||
|
m = MobileViT(config, num_classes=num_classes)
|
||||||
|
return m
|
|
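A quick smoke test can catch shape regressions in the factories above. This is a minimal sketch, assuming the definitions live in `model.py` (the training and prediction scripts in this commit import them under that name); the script name `sanity_check.py` is illustrative.

# sanity_check.py -- minimal sketch; assumes the definitions above are saved as model.py
import torch

from model import mobile_vit_xx_small, mobile_vit_x_small, mobile_vit_small

if __name__ == '__main__':
    x = torch.randn(2, 3, 224, 224)  # [B, C, H, W], the input size used elsewhere in this repo
    for build in (mobile_vit_xx_small, mobile_vit_x_small, mobile_vit_small):
        model = build(num_classes=5)
        model.eval()
        with torch.no_grad():
            y = model(x)
        n_params = sum(p.numel() for p in model.parameters())
        print(f"{build.__name__}: output {tuple(y.shape)}, {n_params / 1e6:.2f}M params")  # expect output (2, 5)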
@ -0,0 +1,176 @@
def get_config(mode: str = "xx_small") -> dict:
    if mode == "xx_small":
        mv2_exp_mult = 2
        config = {
            "layer1": {
                "out_channels": 16,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 24,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 48,
                "transformer_channels": 64,
                "ffn_dim": 128,
                "transformer_blocks": 2,
                "patch_h": 2,  # 8,
                "patch_w": 2,  # 8,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 64,
                "transformer_channels": 80,
                "ffn_dim": 160,
                "transformer_blocks": 4,
                "patch_h": 2,  # 4,
                "patch_w": 2,  # 4,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 80,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "x_small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 48,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 64,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 80,
                "transformer_channels": 120,
                "ffn_dim": 240,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 64,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 128,
                "transformer_channels": 192,
                "ffn_dim": 384,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 160,
                "transformer_channels": 240,
                "ffn_dim": 480,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    else:
        raise NotImplementedError

    for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
        config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})

    return config
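As a rough check of how the three presets scale, the sketch below prints each stage's output width and transformer dimension. It assumes `get_config` is importable; the module name `config` is an assumption, not confirmed by this commit.

# compare_configs.py -- illustrative sketch; the module name "config" is an assumption
from config import get_config

for mode in ("xx_small", "x_small", "small"):
    cfg = get_config(mode)
    widths = [cfg[f"layer{i}"]["out_channels"] for i in range(1, 6)]
    dims = [cfg[f"layer{i}"]["transformer_channels"] for i in (3, 4, 5)]
    print(f"{mode:9s} stage widths: {widths}, transformer dims: {dims}")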
@ -0,0 +1,37 @@
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' means a color image, 'L' means a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for reference, the official default_collate implementation is at
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
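A minimal usage sketch for `MyDataSet` with a `DataLoader`; the two image paths are hypothetical placeholders for real RGB files.

# dataset_demo.py -- minimal sketch; the image file names are hypothetical
from torch.utils.data import DataLoader
from torchvision import transforms

from my_dataset import MyDataSet

tf = transforms.Compose([transforms.Resize(256),
                         transforms.CenterCrop(224),
                         transforms.ToTensor()])
ds = MyDataSet(images_path=["cat1.jpg", "dog1.jpg"], images_class=[0, 1], transform=tf)
loader = DataLoader(ds, batch_size=2, collate_fn=MyDataSet.collate_fn)
images, labels = next(iter(loader))
print(images.shape, labels)  # torch.Size([2, 3, 224, 224]) tensor([0, 1])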
@ -0,0 +1,64 @@
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import mobile_vit_small as create_model
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# configure matplotlib to render Chinese characters
plt.rcParams['font.sans-serif'] = ['SimHei']

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = r"E:\Download\data\train\Acne and Rosacea Photos\acne-closed-comedo-8.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=24).to(device)
    # load model weights
    model_weight_path = "./best300_model_0.7302241690286009.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
@ -0,0 +1,90 @@
import os
import json
import uuid
import torch
from PIL import Image
from torchvision import transforms
from model import mobile_vit_small as create_model


class ImagePredictor:
    def __init__(self, model_path, class_indices_path, img_size=224):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.img_size = img_size
        self.data_transform = transforms.Compose([
            transforms.Resize(int(self.img_size * 1.14)),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Load class indices
        with open(class_indices_path, "r", encoding="utf-8") as f:
            self.class_indict = json.load(f)
        # Load model
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        model = create_model(num_classes=3).to(self.device)
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.eval()
        return model

    def predict_img(self, image_path):
        # Load and transform image
        assert os.path.exists(image_path), f"file: '{image_path}' does not exist."
        img = Image.open(image_path).convert('RGB')
        img = self.data_transform(img)
        img = torch.unsqueeze(img, dim=0)

        # Predict class
        with torch.no_grad():
            output = torch.squeeze(self.model(img.to(self.device))).cpu()
            probabilities = torch.softmax(output, dim=0)
            # cap k at the number of classes so topk cannot fail for small models
            k = min(5, probabilities.size(0))
            top_prob, top_catid = torch.topk(probabilities, k)

        # Top-k results
        top5 = []
        for i in range(top_prob.size(0)):
            top5.append({
                "name": self.class_indict[str(top_catid[i].item())],
                "score": top_prob[i].item(),
                "label": top_catid[i].item()
            })

        # Results dictionary
        results = {"result": top5, "log_id": str(uuid.uuid1())}

        return results

    def predict(self, np_image):
        # Convert numpy image to PIL image
        img = Image.fromarray(np_image).convert('RGB')

        # Transform image
        img = self.data_transform(img)
        img = torch.unsqueeze(img, dim=0)

        # Predict class
        with torch.no_grad():
            output = torch.squeeze(self.model(img.to(self.device))).cpu()
            probabilities = torch.softmax(output, dim=0)
            top_prob, top_catid = torch.topk(probabilities, 1)

        # Top-1 result, same structure as predict_img
        top1 = []
        for i in range(top_prob.size(0)):
            top1.append({
                "name": self.class_indict[str(top_catid[i].item())],
                "score": top_prob[i].item(),
                "label": top_catid[i].item()
            })

        # Results dictionary
        results = {"result": top1, "log_id": str(uuid.uuid1())}

        return results


# Example usage:
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
# result = predictor.predict_img("../tulip.jpg")
# print(result)
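The numpy entry point `predict` is what the webcam script below feeds with face crops. A minimal standalone sketch, where a random array stands in for a real crop and the weight path is the one from the usage comment above:

# numpy_demo.py -- minimal sketch; the random array stands in for a real face crop
import numpy as np

from predict_api import ImagePredictor

predictor = ImagePredictor(model_path="./weights/best_model.pth",
                           class_indices_path="./class_indices.json")
fake_face = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)  # H x W x RGB
result = predictor.predict(fake_face)
print(result["result"][0])  # top-1 entry: {"name": ..., "score": ..., "label": ...}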
@ -0,0 +1,135 @@
import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import mobile_vit_xx_small as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if not os.path.exists("./weights"):
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' does not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
        # drop the classifier weights, since the number of classes differs
        for k in list(weights_dict.keys()):
            if "classifier" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze all weights except the classifier head
            if "classifier" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)

    best_acc = 0.
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "./weights/best_model.pth")

        torch.save(model.state_dict(), "./weights/latest_model.pth")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0002)

    # dataset root directory
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="/data/flower_photos")

    # path to pretrained weights; set to an empty string to train from scratch
    parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
                        help='initial weights path')
    # whether to freeze all weights except the classifier head
    # (store_true avoids the argparse type=bool pitfall, where any non-empty string parses as True)
    parser.add_argument('--freeze-layers', action='store_true')
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
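Typical ways to launch training, assuming the script above is saved as train.py (the filename is an assumption; the flags are the ones defined above, and the dataset paths are illustrative):

# illustrative invocations; train.py is an assumed filename
#   python train.py --data-path /data/flower_photos --num_classes 5 --epochs 10
#   python train.py --data-path /data/my_images --num_classes 24 --weights '' --batch-size 16
# training curves are logged by SummaryWriter (default log dir is ./runs):
#   tensorboard --logdir runs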
@ -0,0 +1,155 @@
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """
    This layer applies a multi-head self- or cross-attention as described in the
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (float): Attention dropout. Default: 0.0
        bias (bool): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: float = 0.0,
        bias: bool = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x_q: Tensor) -> Tensor:
        # [N, P, C]
        b_sz, n_patches, in_channels = x_q.shape

        # self-attention
        # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
        qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

        # [N, P, 3, h, c] -> [N, h, 3, P, c]
        qkv = qkv.transpose(1, 3).contiguous()

        # [N, h, 3, P, c] -> [N, h, P, c] x 3
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        query = query * self.scaling

        # [N, h, P, c] -> [N, h, c, P]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
        attn = torch.matmul(query, key)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
        out = torch.matmul(attn, value)

        # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
        out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
        out = self.out_proj(out)

        return out


class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        ffn_latent_dim (int): Inner dimension of the FFN
        num_heads (int) : Number of heads in multi-head attention. Default: 8
        attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers. Default: 0.0

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        *args,
        **kwargs
    ) -> None:

        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True
        )

        self.pre_norm_mha = nn.Sequential(
            nn.LayerNorm(embed_dim),
            attn_unit,
            nn.Dropout(p=dropout)
        )

        self.pre_norm_ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            nn.SiLU(),
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout)
        )
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.std_dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        # multi-head attention
        res = x
        x = self.pre_norm_mha(x)
        x = x + res

        # feed forward network
        x = x + self.pre_norm_ffn(x)
        return x
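Both modules are shape-preserving, which the following sketch verifies; the module name `transformer` is an assumption about where the two classes above are saved.

# shape_check.py -- minimal sketch; the module name "transformer" is an assumption
import torch

from transformer import MultiHeadAttention, TransformerEncoder

x = torch.randn(4, 256, 64)  # [N, P, C]: batch 4, 256 patches, embedding dim 64
mha = MultiHeadAttention(embed_dim=64, num_heads=4)
enc = TransformerEncoder(embed_dim=64, ffn_latent_dim=128, num_heads=4)
print(mha(x).shape)  # torch.Size([4, 256, 64]), same shape as the input
print(enc(x).shape)  # torch.Size([4, 256, 64])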
@ -0,0 +1,56 @@
import time
import torch

batch_size = 8
in_channels = 32
patch_h = 2
patch_w = 2
num_patch_h = 16
num_patch_w = 16
num_patches = num_patch_h * num_patch_w
patch_area = patch_h * patch_w


def official(x: torch.Tensor):
    # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
    x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
    # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
    x = x.transpose(1, 2)
    # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(batch_size, in_channels, num_patches, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]
    x = x.reshape(batch_size * patch_area, num_patches, -1)

    return x


def my_self(x: torch.Tensor):
    # [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
    x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
    # [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
    x = x.transpose(3, 4)
    # [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(batch_size, in_channels, num_patches, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]
    x = x.reshape(batch_size * patch_area, num_patches, -1)

    return x


if __name__ == '__main__':
    t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
    print(torch.equal(official(t), my_self(t)))

    t1 = time.time()
    for _ in range(1000):
        official(t)
    print(f"official time: {time.time() - t1}")

    t1 = time.time()
    for _ in range(1000):
        my_self(t)
    print(f"self time: {time.time() - t1}")
@ -0,0 +1,179 @@
import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # traverse the folders; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort to keep the order consistent across platforms
    flower_class.sort()
    # build the class names and their corresponding numeric indices
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # all image paths in the training set
    train_images_label = []  # class indices of the training images
    val_images_path = []  # all image paths in the validation set
    val_images_label = []  # class indices of the validation images
    every_class_num = []  # total number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # traverse the files under each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # sort to keep the order consistent across platforms
        images.sort()
        # numeric index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # paths sampled above go into the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # everything else goes into the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must be greater than 0."
    assert len(val_images_path) > 0, "number of validation images must be greater than 0."

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,... with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add value labels on top of the bars
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove x-axis ticks
            plt.yticks([])  # remove y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)  # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)  # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
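`read_split_data` expects one sub-folder per class under the root directory. A minimal sketch of the split step on its own (the dataset path and layout are hypothetical):

# split_demo.py -- minimal sketch; the dataset path is hypothetical
# expected layout: /data/flower_photos/<class_name>/<image>.jpg
from utils import read_split_data

train_paths, train_labels, val_paths, val_labels = read_split_data("/data/flower_photos", val_rate=0.2)
print(len(train_paths), "train /", len(val_paths), "val")
# a class_indices.json file mapping index -> class name is written to the working directory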
@ -0,0 +1,51 @@
import cv2
import dlib
import numpy as np
from PIL import Image

from predict_api import ImagePredictor

# Initialize camera and face detector
cap = cv2.VideoCapture(0)
detector = dlib.get_frontal_face_detector()

# Initialize ImagePredictor
predictor = ImagePredictor(model_path="best_model_'0.8998410174880763'.pth", class_indices_path="./class_indices.json")

while True:
    # Capture frame-by-frame
    ret, frame = cap.read()
    # Stop if the camera returned no frame
    if not ret:
        break

    # Convert the image from BGR color (which OpenCV uses) to RGB color
    rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Perform face detection
    faces = detector(rgb_image)

    # Loop through each face in this frame
    for rect in faces:
        # Get the bounding box coordinates
        x1, y1, x2, y2 = rect.left(), rect.top(), rect.right(), rect.bottom()

        # Crop the face from the frame
        face_image = rgb_image[y1:y2, x1:x2]

        # Use ImagePredictor to predict the class of this face
        result = predictor.predict(face_image)

        # Draw a rectangle around the face
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Display the class name and score
        cv2.putText(frame, f"{result['result'][0]['name']}: {round(result['result'][0]['score'], 4)}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    # Display the resulting frame
    cv2.imshow('Video', frame)

    # Exit loop if 'q' is pressed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# When everything is done, release the capture
cap.release()
cv2.destroyAllWindows()