【发布时间】:2021-04-13 10:37:57
【问题描述】:
在下面的编码示例中,我无法理解为什么输出张量 grid 的形状为
3,28,280。我理解为什么它的高度为28,宽度为280,而不是3。似乎从沿轴0 的所有3 个28x280 阵列上运行plt.imshow() 来看,它们是相同的副本,因为打印其中任何一个给出我想要的形象。
另外,我不明白为什么我可以将grid 作为参数传递给plt.imshow(),因为它应该采用2D 数组,而不是像grid 那样的3D 数组。
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
train_set = torchvision.datasets.FashionMNIST(
root = './pytorch_obj_classifier/data/FashionMNIST',
train = True,
download = True,
transform = transforms.Compose([
transforms.ToTensor()
])
)
sample = next(iter(train_loader))
image,label = sample
print(image.shape)
grid = torchvision.utils.make_grid(image,padding=0, nrow=10)
print(grid.shape)
plt.figure(figsize=(15,15))
grid = np.transpose(grid,(1,2,0))
grid1 = grid[:,:,0]
grid2 = grid[:,:,1]
grid3 = grid[:,:,2]
plt.imshow(grid1,cmap = 'gray')
plt.imshow(grid2,cmap = 'gray')
plt.imshow(grid3,cmap = 'gray')
plt.imshow(grid,cmap = 'gray')
【问题讨论】:
标签: python python-3.x pytorch torchvision