【发布时间】:2021-06-01 18:42:28
【问题描述】:
我正在编写一个自定义数据加载器,而返回的值让我感到困惑。
import torch
import torch.nn as nn
import numpy as np
import torch.utils.data as data_utils
class TestDataset:
def __init__(self):
self.db = np.random.randn(20, 3, 60, 60)
def __getitem__(self, idx):
img = self.db[idx]
return img, img.shape[1:]
def __len__(self):
return self.db.shape[0]
if __name__ == '__main__':
test_dataset = TestDataset()
test_dataloader = data_utils.DataLoader(test_dataset,
batch_size=1,
num_workers=4,
shuffle=False, \
pin_memory=True
)
for i, (imgs, sizes) in enumerate(test_dataloader):
print(imgs.size()) # torch.Size([1, 3, 60, 60])
print(sizes) # [tensor([60]), tensor([60])]
break
为什么“sizes”返回一个长度为 2 的列表?我认为它应该是“torch.Size([1, 2])”,它表示图像的高度和宽度(1 batch_size)。
更进一步,返回列表的长度是否应该与batch_size相同?如果我想得到尺寸,我必须写“sizes = [sizes[0][0].item(), sizes[1][0].item()]”。这让我很困惑。
感谢您的宝贵时间。
【问题讨论】:
标签: python pytorch dataloader