【问题标题】:Problem getting the labels of training-set获取训练集标签的问题
【发布时间】:2019-10-25 08:55:49
【问题描述】:

我使用train_test_split 函数将我的数据分为X_trainX_testy_trainy_test,然后使用utils.data.DataLoader 将其提供给我的 CNN,但问题是我这样做了不知道如何访问我的标签张量以制作混淆矩阵并将它们与我的预测张量进行比较。我知道这是一个基本问题,但无论如何感谢您的帮助。

X_train, X_test, y_train, y_test = train_test_split(faces, emotions, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=41)

我用过

train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)

用于将数据提供给我的网络 您似乎可以通过在您的 train_set 之后键入目标属性来访问您的标签,例如 train_set.targets,但它对我不起作用。如何获取我的标签?

【问题讨论】:

    标签: python label classification conv-neural-network dataloader


    【解决方案1】:

    PyTorch 的 DataLoader 对象大致是这样使用的:

    for i, (inputs, labels) in enumerate(dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
    
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
    

    一般来说,我建议使用两个 DataLoader,一个用于训练,一个用于测试/验证。由于您想创建一个混淆矩阵,您可以通过您的 numpy 数组 y_train 和您的预测 preds 来访问您的标签,例如通过在循环内将它们连接到一个 numpy 数组。

    有关如何使用 DataLoader 的更多信息,我建议查看这个非常好的教程: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

    https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-07-29
      • 2021-05-20
      • 2020-06-17
      • 2016-07-08
      • 2012-01-08
      • 1970-01-01
      • 2019-01-31
      相关资源
      最近更新 更多