【问题标题】:IndexError: list index out of range in prediction of imagesIndexError:图像预测中的列表索引超出范围
【发布时间】:2021-07-12 00:47:38
【问题描述】:

我正在对写入所有类名称的图像进行预测,在测试文件夹中,我有 20 张图像。请给我一些提示,为什么我会出错?我们如何检查模型的索引?

代码

import numpy as np
import sys, random
import torch
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import glob

# Paths for image directory and model
IMDIR = './test'
MODEL = 'checkpoint/resnet18/Monday_31_May_2021_21h_25m_05s/resnet18-1000-regular.pth'

# Load the model for testing
model = models.resnet18()

model.named_children()

torch.save(model.state_dict, MODEL)
model.eval()

# Class labels for prediction
class_names = ['BC', 'BK', 'CC', 'CL', 'CM', 'DF', 'DG', 'DS', 'HL', 'IF', 'JD', 'JS', 'LD', 'LP', 'LS', 'PO', 'RI',
               'SD', 'SG', 'TO']


# Retreive 9 random images from directory
files = Path(IMDIR).resolve().glob('*.*')
print(files)

images = random.sample(list(files), 1)
print(images)
# Configure plots
fig = plt.figure(figsize=(9, 9))
rows, cols = 3, 3

# Preprocessing transformations
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    # transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(0.5306, 0.1348)
])

# Enable gpu mode, if cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Perform prediction and plot results
with torch.no_grad():
    for num, img in enumerate(images):
        img = Image.open(img).convert('RGB')
        inputs = preprocess(img).unsqueeze(0).cpu()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        print(preds)
        label = class_names[preds]
        plt.subplot(rows, cols, num + 1)
        plt.title("Pred: " + label)
        plt.axis('off')
        plt.imshow(img)
'''
Sample run: python test.py test
'''

追溯

Traceback (most recent call last):
  File "/media/khawar/HDD_Khawar/CVPR/pytorch-cifar100/test_box.py", line 57, in <module>
    label = class_names[preds]
IndexError: list index out of range

【问题讨论】:

  • print(preds) 的输出是什么?很明显,max 函数返回的值高于 class_names 的长度。
  • 张量([86])。我需要类标签名称之类的输出
  • 您正在加载经过 imagenet 训练的 resnet18(似乎也没有预训练,这是一个错误吗?),它有 1000 个类。您正在尝试将这些分类到您的 10-20 类名称中。
  • 这个 resnet18 是我的模型
  • 你根本没有加载模型?

标签: pytorch classification prediction image-classification


【解决方案1】:

您的错误源于您没有对 resnet 模型的线性层进行任何修改。

我建议添加此代码:

# What you have
model = models.resnet18()

# What you need
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, len(class_names)))

这会将最后的线性层更改为输出正确数量的节点

萨塔克

【讨论】:

    猜你喜欢
    • 2011-10-31
    • 2015-06-26
    • 1970-01-01
    • 1970-01-01
    • 2020-01-27
    相关资源
    最近更新 更多