【发布时间】:2021-07-12 00:47:38
【问题描述】:
我正在对写入所有类名称的图像进行预测,在测试文件夹中,我有 20 张图像。请给我一些提示,为什么我会出错?我们如何检查模型的索引?
代码
import numpy as np
import sys, random
import torch
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import glob
# Paths for image directory and model
IMDIR = './test'
MODEL = 'checkpoint/resnet18/Monday_31_May_2021_21h_25m_05s/resnet18-1000-regular.pth'
# Load the model for testing
model = models.resnet18()
model.named_children()
torch.save(model.state_dict, MODEL)
model.eval()
# Class labels for prediction
class_names = ['BC', 'BK', 'CC', 'CL', 'CM', 'DF', 'DG', 'DS', 'HL', 'IF', 'JD', 'JS', 'LD', 'LP', 'LS', 'PO', 'RI',
'SD', 'SG', 'TO']
# Retreive 9 random images from directory
files = Path(IMDIR).resolve().glob('*.*')
print(files)
images = random.sample(list(files), 1)
print(images)
# Configure plots
fig = plt.figure(figsize=(9, 9))
rows, cols = 3, 3
# Preprocessing transformations
preprocess = transforms.Compose([
transforms.Resize((256, 256)),
# transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize(0.5306, 0.1348)
])
# Enable gpu mode, if cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Perform prediction and plot results
with torch.no_grad():
for num, img in enumerate(images):
img = Image.open(img).convert('RGB')
inputs = preprocess(img).unsqueeze(0).cpu()
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
print(preds)
label = class_names[preds]
plt.subplot(rows, cols, num + 1)
plt.title("Pred: " + label)
plt.axis('off')
plt.imshow(img)
'''
Sample run: python test.py test
'''
追溯
Traceback (most recent call last):
File "/media/khawar/HDD_Khawar/CVPR/pytorch-cifar100/test_box.py", line 57, in <module>
label = class_names[preds]
IndexError: list index out of range
【问题讨论】:
-
print(preds) 的输出是什么?很明显,max 函数返回的值高于 class_names 的长度。
-
张量([86])。我需要类标签名称之类的输出
-
您正在加载经过 imagenet 训练的 resnet18(似乎也没有预训练,这是一个错误吗?),它有 1000 个类。您正在尝试将这些分类到您的 10-20 类名称中。
-
这个 resnet18 是我的模型
-
你根本没有加载模型?
标签: pytorch classification prediction image-classification