【发布时间】:2020-05-25 01:38:06
【问题描述】:
我需要可视化一个对 14 个不同类别进行分类的 VGG16 模型的输出。我加载了训练好的模型,并用 Identity() 层替换了分类器的最后一层,但得到的输出并没有按类别区分开。
下面是代码片段(snippet),这里的样本数量是 1000 张图片。
# Restore model and optimizer state from the checkpoint file named after
# the requested epoch.
epoch = 800
PATH = 'vgg16_epoch{}.pth'.format(epoch)
checkpoint = torch.load(PATH)
# Each pair is (object to restore, checkpoint key holding its state dict);
# the model is restored before the optimizer.
for restorable, state_key in (
    (model, 'model_state_dict'),
    (optimizer, 'optimizer_state_dict'),
):
    restorable.load_state_dict(checkpoint[state_key])
# From here on, use the epoch recorded inside the checkpoint rather than the
# one used to build the file name.
epoch = checkpoint['epoch']
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged.

    Useful as a stand-in for a layer that should be disabled while keeping
    the surrounding module structure intact.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x
# Swap the last classifier layer for a pass-through so the model emits the
# activations that feed that layer instead of class scores.
model.classifier[6] = Identity()
model.eval()

feature_batches = []
targets = []
with torch.no_grad():
    for step, (t_image, target, classess, image_path) in enumerate(test_loader):
        # Labels are only needed on the host for colouring the plot, so they
        # are not sent to the GPU first.
        targets.append(target.cpu().numpy())
        logits = model(t_image.cuda())
        print(logits.shape)
        feature_batches.append(logits.cpu().numpy())

if not feature_batches:
    raise RuntimeError("test_loader yielded no batches; nothing to embed")

# Concatenate once after the loop; appending to an array on every batch
# recopies everything collected so far.
logits_list = numpy.concatenate(feature_batches, axis=0)
# One class label per sample, in the same order as the rows of logits_list.
# ravel() also covers labels delivered as (batch, 1) - TODO confirm the
# loader's label shape.
target_ids = numpy.concatenate(targets, axis=0).ravel()
print(logits_list.shape)

# The iteration count is left at scikit-learn's default of 1000: the keyword
# for it was renamed from n_iter to max_iter, so omitting it keeps this call
# valid on both sides of the rename.
tsne = TSNE(n_components=2, verbose=1, perplexity=10)
tsne_results = tsne.fit_transform(logits_list)

# Colour every point by its class label. Colouring by range(len(targets))
# uses the batch count, which has nothing to do with the classes and does
# not even match the number of points unless the batch size is 1.
# vmin/vmax centre each of the 14 colour bins on an integer label; this
# assumes labels are the integers 0..13 - verify against the dataset.
plt.scatter(
    tsne_results[:, 0],
    tsne_results[:, 1],
    c=target_ids,
    cmap=plt.get_cmap("jet", 14),
    vmin=-0.5,
    vmax=13.5,
)
plt.colorbar(ticks=range(14))
# No plt.legend(): the scatter has no labelled artists, so the call would
# only emit a warning; the colorbar already maps colours to classes.
plt.show()
【问题讨论】:
标签: python-3.x matplotlib pytorch scatter-plot