【问题标题】:How to get only specific classes from PyTorch's FashionMNIST dataset?如何从 PyTorch Fashion MNIST 数据集中仅获取特定类?
【发布时间】:2022-01-12 02:45:17
【问题描述】:

FashionMNIST 数据集有 10 个不同的输出类。我怎样才能得到这个数据集的一个子集,只有特定的类?就我而言,我只想要运动鞋、套头衫、凉鞋和衬衫类的图像(它们的类分别是 7、2、5 和 6)。

这就是我加载数据集的方式。

train_dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())

我遵循的方法如下。 逐个遍历数据集,然后将返回的元组中的第一个元素(即类)与我需要的类进行比较。我被困在这里。如果返回值为 true,我如何将此观察结果附加/添加到空数据集中?

sneaker = 0
pullover = 0
sandal = 0
shirt = 0
for i in range(60000):
    if train_dataset_full[i][1] == 7:
        sneaker += 1
    elif train_dataset_full[i][1] == 2:
        pullover += 1
    elif train_dataset_full[i][1] == 5:
        sandal += 1
    elif train_dataset_full[i][1] == 6:
        shirt += 1

现在,代替sneaker += 1pullover += 1sandal += 1shirt += 1,我想做类似empty_dataset.append(train_dataset_full[i]) 或类似的事情。

如果上述方法不正确,请提出其他方法。

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    终于找到答案了。

    dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())
    # Selecting classes 7, 2, 5 and 6
    idx = (dataset_full.targets==7) | (dataset_full.targets==2) | (dataset_full.targets==5) | (dataset_full.targets==6)
    dataset_full.targets = dataset_full.targets[idx]
    dataset_full.data = dataset_full.data[idx]
    

    【讨论】:

      【解决方案2】:

      您可以使用列表推导来匹配标签。例如

      idx = dataset.train_labels == 1
      dataset.train_labels = dataset.train_labels[idx]
      

      这将只选择你想要的标签。

      【讨论】:

      • 是的,我最终找到了这个。但是,train_labels 应该替换为目标。
      【解决方案3】:

      我无法使用dataset.train_labelsdataset.data,因此我使用DataLoader 加载了带有所有标签的完整数据集,然后在训练步骤中选择了所需的标签。就我而言,标签是 3 和 4。不确定我的方法是否正确。

      for epoch in range(2):  # loop over the dataset multiple times
          running_loss = 0.0
          for i, data in enumerate(train_dataloader, 0):
              # get the inputs; data is a list of [inputs, labels]
              inputs, labels = data
              if (labels==4)|(labels==3):
      
              # zero the parameter gradients
                  optimizer.zero_grad()
      
              # forward + backward + optimize
                  outputs = net(inputs)
                  loss = criterion(outputs, labels)
                  loss.backward()
                  optimizer.step()
      
              # print statistics
                  running_loss += loss.item()
                  if i % 2000 == 1999:    # print every 2000 mini-batches
                      print('[%d, %5d] loss: %.3f' %
                            (epoch + 1, i + 1, running_loss / 2000))
                      running_loss = 0.0
      
      print('Finished Training')
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2023-02-15
        • 1970-01-01
        • 2019-04-10
        • 2018-12-14
        • 1970-01-01
        • 1970-01-01
        • 2021-10-19
        相关资源
        最近更新 更多