Here are some pointers. In PyTorch you have:
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch

def image_loader(transform, image_name):
    image = Image.open(image_name).convert('RGB')
    image = transform(image).float()  # ToTensor already returns a torch.Tensor
    image = image.unsqueeze(0)        # add a batch dimension
    return image

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# check: visualize
i = image_loader(data_transforms, '/content/1.png')
i.shape
plt.figure(figsize=(25, 10))
plt.subplot(121); plt.imshow(np.array(i[0]).transpose(1, 2, 0));
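As a framework-free sanity check (a sketch using NumPy only, with a made-up 2x2 image), the ToTensor + Normalize pipeline above amounts to scaling raw uint8 values to [0, 1] and then applying (x - mean) / std per channel:

```python
import numpy as np

# hypothetical 2x2 RGB image with raw uint8 values (HWC layout)
raw = np.array([[[0, 128, 255],
                 [64, 64, 64]],
                [[255, 0, 0],
                 [10, 200, 30]]], dtype=np.uint8)

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# what ToTensor does: scale to [0, 1]
scaled = raw.astype(np.float64) / 255.0

# what Normalize does: per-channel (x - mean) / std
normalized = (scaled - mean) / std
```

(PyTorch stores the result in CHW order rather than HWC, which is why the visualization above needs the `.transpose(1, 2, 0)`.)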
In TensorFlow, you can achieve the same as follows:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image

def transform(image, mean, std):
    # per-channel normalization: (x - mean) / std
    for channel in range(3):
        image[:, :, channel] = (image[:, :, channel] - mean[channel]) / std[channel]
    return image

def image_loader(image_name):
    image = Image.open(image_name).convert('RGB')
    image = transform(np.array(image) / 255,
                      mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225])
    image = tf.cast(image, tf.float32)
    image = tf.expand_dims(image, 0)  # add a batch dimension
    return image

# check: visualize
i = image_loader('/content/1.png')
i.shape
plt.figure(figsize=(25, 10))
plt.subplot(121); plt.imshow(i[0]);
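The explicit per-channel loop in `transform` works, but the same result can be had with NumPy broadcasting over the channel axis. A minimal sketch showing the two forms agree (the random test image here is just for illustration):

```python
import numpy as np

def transform_loop(image, mean, std):
    # the per-channel loop from the answer above (copy to avoid mutating input)
    image = image.copy()
    for channel in range(3):
        image[:, :, channel] = (image[:, :, channel] - mean[channel]) / std[channel]
    return image

def transform_vec(image, mean, std):
    # equivalent vectorized form: mean/std broadcast over the last (channel) axis
    return (image - np.asarray(mean)) / np.asarray(std)

rng = np.random.default_rng(0)
img = rng.random((4, 4, 3))
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

assert np.allclose(transform_loop(img, mean, std), transform_vec(img, mean, std))
```

The vectorized form also avoids the in-place mutation of the caller's array.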
These should produce the same output. Note that in the second case we define the transform function ourselves, taken from another answer (here); that is fine, but you can also check the alternative linked there (see this answer). FYI, I think you are normalizing the input twice in PyTorch: transforms.ToTensor already converts from [0..255] to [0..1], and transforms.Normalize is applied on top of that, which I think you should take into account.
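To make the scaling point concrete: those ImageNet mean/std values assume input already in [0, 1] (i.e. after ToTensor). A quick numeric sketch of what the value ranges look like when the scaling step is applied correctly versus skipped:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Normalize expects input in [0, 1] (after ToTensor):
lo = (0.0 - mean) / std  # darkest pixel, roughly -2.1 .. -1.8 per channel
hi = (1.0 - mean) / std  # brightest pixel, roughly 2.2 .. 2.7 per channel

# applying the same stats to raw [0, 255] values instead would give
# values around a thousand -- a sign the [0, 1] scaling was skipped
wrong = (255.0 - mean) / std
```

If a pipeline divides by 255 *and* also relies on ToTensor's scaling, the inputs end up in [0, 1/255] before Normalize, which silently wrecks the statistics in the other direction.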