【发布时间】:2020-12-18 07:34:33
【问题描述】:
为了打印具有多个模型的 ROC 曲线,我遇到了这个特殊错误。需要帮助
from tensorflow.keras.models import load_model
def dense():
return (load_model('DenseNet201.h5'))
def mobile():
return(load_model('MobileNet.h5'))
def res():
return(load_model('ResNet50V2.h5'))
def vgg():
return(load_model('VGG16.h5'))
models = [
{
'label': 'DenseNet201',
'model': dense(),
},
{
'label': 'MobileNet',
'model':mobile(),
},
{
'label': 'ResNet50V2',
'model':res(),
},
{
'label': 'VGG16',
'model':vgg(),
}]
from sklearn import metrics
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
plt.figure()
# Below for loop iterates through your models list
for m in models:
model = m['model'] # select the model
#model.fit(X_train, y_train) # train the model
y_pred=model.predict(X_test) # predict the test data
# Compute False postive rate, and True positive rate
#fpr, tpr, thresholds = metrics.roc_curve(y_test, model.y_pred_bin(X_test)[:,1])
fpr, tpr, thresholds = metrics.roc_curve(y_test, model.predict_proba(X_test)[:,1])
# Calculate Area under the curve to display on the plot
auc = metrics.roc_auc_score(y_test,model.predict(X_test))
# Now, plot the computed values
plt.plot(fpr, tpr, label='%s ROC (area = %0.2f)' % (m['label'], auc))
# Custom settings for the plot
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity(False Positive Rate)')
plt.ylabel('Sensitivity(True Positive Rate)')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show() # Display
我在一个函数中加载了我的预训练模型并使用此代码返回它。我创建了一个列表,它将被迭代,它会调用加载这些模型的函数,因此将绘制每个模型的 ROC 曲线。
完整的追溯
> AttributeError Traceback (most recent call
> last) <ipython-input-43-f353a6208636> in <module>()
> 11 # Compute False postive rate, and True positive rate
> 12 #fpr, tpr, thresholds = metrics.roc_curve(y_test, model.y_pred_bin(X_test)[:,1])
> ---> 13 pred_prob = model.predict_proba(X_test)
> 14 fpr, tpr, thresholds = metrics.roc_curve(y_test, pred_prob[:,1])
> 15 # Calculate Area under the curve to display on the plot
>
> AttributeError: 'Functional' object has no attribute 'predict_proba'
【问题讨论】:
-
能否包含完整的回溯?
-
predict_proba仅适用于某些类型的模型...很可能您的列表中有一个模型不支持该方法...您可以找出导致错误的模型与try:model.predict_proba(X_test)except:print(model_name) -
请看一看。我用完整的回溯更新了代码
标签: python matplotlib scikit-learn roc