【问题标题】:How to get the filename of a sample from a DataLoader?如何从 DataLoader 获取样本的文件名?
【发布时间】:2019-11-04 01:05:53
【问题描述】:

我需要编写一个文件,其中包含我训练的卷积神经网络的数据测试结果。数据包括语音数据收集。文件格式需要是“文件名,预测”,但我很难提取文件名。我像这样加载数据:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

我正在尝试按如下方式写入文件:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

os.listdir(TESTH_DATA_PATH + "/all")[i]的问题在于它与test_loader的加载文件顺序不同步。我能做什么?

【问题讨论】:

    标签: python machine-learning pytorch torchvision


    【解决方案1】:

    嗯,这取决于您的Dataset 是如何实现的。例如,在torchvision.datasets.MNIST(...) 的情况下,您无法检索文件名,因为没有单个样本的文件名(MNIST 样本为loaded in a different way)。

    由于您没有展示您的Dataset 实现,我将告诉您如何使用torchvision.datasets.ImageFolder(...)(或任何torchvision.datasets.DatasetFolder(...))来完成:

    f = open("test_y", "w")
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader, 0):
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            sample_fname, _ = test_loader.dataset.samples[i]
            f.write("{}, {}\n".format(sample_fname, predicted.item()))
    f.close()
    

    您可以看到在__getitem__(self, index) 期间检索文件的路径,特别是here

    如果您实现了自己的Dataset(并且可能希望支持shufflebatch_size > 1),那么我将在__getitem__(...) 调用中返回sample_fname 并执行以下操作:

    for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
        # [...]
    

    这样你就不需要关心shuffle。如果 batch_size 大于 1,则需要将循环的内容更改为更通用的内容,例如:

    f = open("test_y", "w")
    for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
        outputs = model(images)
        pred = torch.max(outputs, 1)[1]
        f.write("\n".join([
            ", ".join(x)
            for x in zip(map(str, pred.cpu().tolist()), samples_fname)
        ]) + "\n")
    f.close()
    

    【讨论】:

    • 感谢您的提示!我可以从 datasets.ImageFolder.samples[i][0] 读取文件名列表
    【解决方案2】:

    一般情况下,DataLoader 会为您提供其内部数据集中的批次。

    @Barriel 在单/多标签分类问题的情况下提到,DataLoader 没有图像文件名,只有表示图像的张量和类/标签。

    但是,DataLoader 构造函数在加载对象时可能会占用一些小东西(与数据集一起,您可以根据需要打包目标/标签和文件名),甚至是数据框

    这样,DataLoader 可能会以某种方式获取您需要的内容。

    【讨论】:

      【解决方案3】:

      如果你使用 PyCharm 或任何有调试工具的 IDE,让我们用它来看看你的 data_loader,希望你能看到一个文件名列表,就像我的例子一样。

      就我而言, 我的 data_loader 是由 mmsegmentation 创建的。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2022-01-18
        • 2011-03-04
        • 1970-01-01
        • 1970-01-01
        • 2012-01-26
        • 2021-11-26
        • 1970-01-01
        • 2010-10-17
        相关资源
        最近更新 更多