【问题标题】:Tensorflow version of Pytorch TransformsTensorFlow 版本的 Pytorch 变换
【发布时间】:2021-09-03 22:00:54
【问题描述】:

我有以下代码用于在 模型中执行推理之前准备图像:

def image_loader(transform, image_name):
    image = Image.open(image_name)
    #transform
    image = transform(image).float()
    image = torch.tensor(image)
    image = image.unsqueeze(0)
    return image

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

我已将模型转换为 Tensorflow 模型,但是,我不确定如何在推理之前对图像进行类似的转换,因为似乎没有 等效项。有什么建议吗?

【问题讨论】:

    标签: pytorch tensorflow keras python tensorflow machine-learning keras pytorch


    【解决方案1】:

    这里有一些指针,在 你有

    from torchvision import transforms
    from PIL import Image 
    import torch 
    
    def image_loader(transform, image_name):
        image = Image.open(image_name).convert('RGB')
        image = transform(image).float()
        image = torch.tensor(image)
        image = image.unsqueeze(0)
        return image
    
    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # check: visualize 
    i = image_loader(data_transforms, '/content/1.png')
    i.shape
    
    plt.figure(figsize=(25,10))
    subplot(121); imshow(np.array(i[0]).transpose(1, 2, 0)); 
    

    而在,你可以通过如下方式实现

    def transform(image, mean, std):
        for channel in range(3):
            image[:,:,channel] = (image[:,:,channel] - mean[channel]) / std[channel]
        return image
    
    def image_loader(image_name):
        image = Image.open(image_name).convert('RGB')
        image = transform(np.array(image)/255, 
                           mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
        image = tf.cast(image, tf.float32)
        image = tf.expand_dims(image, 0)
        return image 
    
    # check: visualize 
    i = image_loader('/content/1.png')
    i.shape 
    
    plt.figure(figsize=(25,10))
    subplot(121); imshow(i[0]); 
    

    这应该输出相同。注意,在第二种情况下,我们定义了transform函数,来自另一个OP,here,没关系,但是,你也可以检查@987654324 @,详见this answer。仅供参考,我认为您在 pytorch 中对输入进行了两次标准化,在 transform.ToTensor 中从 [0..255] 转换为 [0..1],接下来是 transform.Normalize,我认为您应该考虑这一点。

    【讨论】:

      猜你喜欢
      • 2019-10-03
      • 2022-08-19
      • 2020-10-19
      • 1970-01-01
      • 2021-09-08
      • 2021-01-11
      • 2022-10-24
      • 2018-09-22
      • 2021-12-29
      相关资源
      最近更新 更多