【问题标题】:Plot the transformed (augmented) images in pytorch在 pytorch 中绘制转换(增强)图像
【发布时间】:2021-11-05 16:39:40
【问题描述】:

我想使用一种图像增强技术(例如旋转或水平翻转)并将其应用于 CIFAR-10 数据集的一些图像并在 PyTorch 中绘制它们。

我知道我们可以使用以下代码来增强图像:

from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10

data_transforms = transforms.Compose([
        # add augmentations
        transforms.RandomHorizontalFlip(p=0.5),
        # The output of torchvision datasets are PILImage images of range [0, 1].
        # We transform them to Tensors of normalized range [-1, 1]
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

然后当我想加载 Cifar10 数据集时,我使用了上面的转换:

train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms['train'])

据我所知,当使用此代码时,所有 CIFAR10 数据集都会被转换。

问题

我的问题是如何对数据集中的某些图像使用数据转换或增强技术并绘制它们?例如 10 张图像及其增强图像。

【问题讨论】:

    标签: python machine-learning deep-learning pytorch image-augmentation


    【解决方案1】:

    使用此代码时,所有 CIFAR10 数据集都会被转换

    实际上,只有当用户通过__getitem__ 函数或数据加载器获取数据集中的图像时,才会调用转换管道。所以在这个时间点,train_set 不包含增强图像,它们是动态转换的。


    您将需要构建另一个没有扩充的数据集。

    >>> non_augmented = CIFAR10(
    ...     root='./data/',
    ...     train=True,
    ...     download=True)
    
    >>> train_set = CIFAR10(
    ...     root='./data/',
    ...     train=True,
    ...     download=True,
    ...     transform=data_transforms)
    

    将一些图像堆叠在一起:

    >>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
                            *[train_set[i][0] for i in range(10)]))
    
    >>> imgs.shape
    torch.Size([20, 3, 32, 32])
    

    然后torchvision.utils.make_grid 可用于创建所需的布局:

    >>> grid = torchvision.utils.make_grid(imgs, nrow=10)
    

    你有它!

    >>> transforms.ToPILImage()(grid)
    

    【讨论】:

    • 非常感谢。
    猜你喜欢
    • 2021-03-05
    • 1970-01-01
    • 1970-01-01
    • 2019-08-06
    • 2020-08-05
    • 2021-06-06
    • 1970-01-01
    • 2020-10-28
    • 2020-09-05
    相关资源
    最近更新 更多