39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
|
from PIL import Image
|
|||
|
import torch
|
|||
|
from torch.utils.data import Dataset
|
|||
|
|
|||
|
|
|||
|
class MyDataSet(Dataset):
|
|||
|
"""自定义数据集"""
|
|||
|
|
|||
|
def __init__(self, images_path: list, images_class: list, transform=None):
|
|||
|
self.images_path = images_path
|
|||
|
self.images_class = images_class
|
|||
|
self.transform = transform
|
|||
|
|
|||
|
def __len__(self):
|
|||
|
return len(self.images_path)
|
|||
|
|
|||
|
def __getitem__(self, item):
|
|||
|
img = Image.open(self.images_path[item])
|
|||
|
# RGB为彩色图片,L为灰度图片
|
|||
|
if img.mode != 'RGB':
|
|||
|
# img = img.convert('RGB')
|
|||
|
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
|
|||
|
label = self.images_class[item]
|
|||
|
|
|||
|
if self.transform is not None:
|
|||
|
img = self.transform(img)
|
|||
|
|
|||
|
return img, label
|
|||
|
|
|||
|
@staticmethod
|
|||
|
def collate_fn(batch):
|
|||
|
# 官方实现的default_collate可以参考
|
|||
|
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
|
|||
|
images, labels = tuple(zip(*batch))
|
|||
|
|
|||
|
images = torch.stack(images, dim=0)
|
|||
|
labels = torch.as_tensor(labels)
|
|||
|
return images, labels
|