【发布时间】:2020-10-06 09:08:48
【问题描述】:
在我的自定义数据集中,我想将transforms.Compose() 应用于 NumPy 数组。
我的图像采用 NumPy 数组格式,形状为 (num_samples, width, height, channels)。
如何将以下转换应用于完整的 numpy 数组?
img_transform = transforms.Compose([
transforms.Scale((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
])
我的尝试以多个错误结束,因为转换接受 PIL 图像而不是 4 维 NumPy 数组。
from torchvision import transforms
import numpy as np
import torch
img_transform = transforms.Compose([
transforms.Scale((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
])
a = np.random.randint(0,256, (299,299,3))
print(a.shape)
img_transform(a)
【问题讨论】:
标签: pytorch torchvision