【发布时间】:2020-06-17 14:09:25
【问题描述】:
Pytorch 的torchvision 包提供pre-trained neural networks 用于图像分类。我一直在使用以下代码使用 Alexnet 对图像进行分类(注意:其中一些代码来自 this webpage):
from PIL import Image
import torch
from torchvision import transforms
from torchvision import models
# function to transform image
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
# image
img = Image.open('/path/to/image.jpg')
img = transform(img)
img = torch.unsqueeze(img, 0)
# alexnet
alexnet = models.alexnet(pretrained=True)
alexnet.eval()
out = alexnet(img)
percents = torch.nn.functional.softmax(out, dim=1)[0] * 100
top5_vals, top5_inds = percents.topk(5)
共有 1,000 个类,top5_inds 变量为我提供了前 5 个类的索引。但是我如何获得相关的标签(例如蜗牛、篮球、香蕉)?我似乎找不到任何类型的列表作为 Pytorch 文档或 alexnet 变量的一部分。
【问题讨论】:
标签: image-processing deep-learning classification pytorch torchvision