【发布时间】:2020-04-21 13:49:29
【问题描述】:
我想做类激活图,所以我写了代码
from keras.datasets import mnist
from keras.layers import Conv2D, Dense, GlobalAveragePooling2D
from keras.models import Model, Input
from keras.utils import to_categorical
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train_resized = x_train.reshape((60000, 28, 28, 1))
x_test_resized = x_test.reshape((10000, 28, 28, 1))
y_train_hot_encoded = to_categorical(y_train)
y_test_hot_encoded = to_categorical(y_test)
inputs = Input(shape=(28,28, 1))
x = Conv2D(64, (3,3), activation='relu')(inputs)
x = Conv2D(64, (3,3), activation='relu')(x)
x = GlobalAveragePooling2D()(x)
predictions = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train_resized, y_train_hot_encoded, epochs=30, batch_size=256, shuffle=True, validation_split=0.3)
工作正常,所以我已经导入了 Visualize_cam 模块
from vis.visualization import visualize_cam
import matplotlib.pyplot as plt
import numpy as np
for i in range(10):
ind = np.where(y_test == i)[0][0]
plt.subplot(141)
plt.imshow(x_test_resized[ind].reshape((28,28)))
for j,modifier in enumerate([None, 'guided', 'relu']):
heat_map = visualize_cam(model, 4, y_test[ind], x_test_resized[ind], backprop_modifier=modifier)
plt.subplot(1,4,j+2)
plt.imshow(heat_map)
plt.show()
但visualize_cam 运行不佳
我尝试了很多次来修复模块,但它并不顺利 (这取决于 scipy 哪个版本低于 1.3。但是)
所以我必须在没有那个模块的情况下实现 cam
是否有任何解决方案可以将 Visualize_cam 替换为其他选项来实现 CAM?
【问题讨论】:
标签: python machine-learning neural-network conv-neural-network