【问题标题】:I want to see data from torch.utils.data.DataLoader. How Can I?我想查看来自 torch.utils.data.DataLoader 的数据。我怎样才能?
【发布时间】:2021-12-04 16:58:24
【问题描述】:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

train_set = torchvision.datasets.MNIST(root = './data/MNIST',train = True,download = True,\transform = transfroms.Compose([transfroms.ToTensor()])

print(len(train_set))
# 60000

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
print(len(train_loader))
# 600

好像是因为batch_size,train_loader的长度减少了。

我认为一批中有 100 个张量和一个分类。 我只想看看它的元素或形状。我能怎么做? 还有,

### Model Omitted ###
model = ConvNet().to(device) 
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

for epoch in range(5): 
    avg_cost = 0
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad() 
        hypothesis = model(data) 
        cost = criterion(hypothesis, target) 
        cost.backward() 
        optimizer.step() 
        avg_cost += cost / len(train_loader)

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

我认为每个 epoch 的训练训练有 60,000 个张量,对吗?那么我认为avg_cost应该除以60,000,而不是600(这是len(train_loader))......我错了吗?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    您可以使用下面的代码从trainloader 获取一批火车数据,您可以轻松检查它的形状。我希望这可能有助于得到你想要的。

    batch= iter(trainloader)
    images, labels = batch.next()
    
    print(images.shape)
    # torch.Size([num_samples, in_channels, H, W])
    
    print(labels.shape)
    

    【讨论】:

    • 我可以看到它的形状。但是,我怎样才能看到它的结构和元素呢?我的意思是,我想检查,例如batch1、batch2、batch3.....(也许对于batch,有100个张量)和batch中的元素
    猜你喜欢
    • 2018-04-09
    • 2016-07-20
    • 2021-11-25
    • 2020-09-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-10-10
    • 1970-01-01
    相关资源
    最近更新 更多