【问题标题】:Load models in Keras在 Keras 中加载模型
【发布时间】:2020-05-21 22:20:57
【问题描述】:

我使用此代码使用客户指标 (AUC) 在 Keras 中加载模型,但这不起作用。你能帮我解决这个问题吗?

train_datagen = ImageDataGenerator(rescale=1/255)
val_datagen = ImageDataGenerator(rescale=1/255)

train_generator = train_datagen.flow_from_directory(
                        train_dir,
                        target_size=(32, 32),
                        batch_size=10,
                        class_mode='binary')
val_generator = val_datagen.flow_from_directory(
                        val_dir, 
                        target_size=(32, 32),
                        batch_size=10,
                        class_mode='binary')

model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', 
              optimizer='rmsprop', 
              metrics=[keras.metrics.AUC(name='auc')])

history = model.fit_generator(train_generator,
                              steps_per_epoch=1405,
                              epochs=1,
                              validation_data=val_generator,
                              validation_steps=10)

model.save('baseline.h5')

model1 = models.load_model('baseline.h5')

我收到了一个 ValueError

ValueError: Unknown metric function: {'class_name': 'AUC', 'config': {'name': 'auc', 'dtype': 'float32', 'num_thresholds': 200, 'curve': 'ROC', 'summation_method': 'interpolation', 'thresholds': [0.005025125628140704, 0.010050251256281407, 0.01507537688442211, 0.020100502512562814

编辑:我添加了导入。我听说过 load_model 方法中的参数“customer_objects”。但我试过了:'custom_object'={'auc':keras.metrics.AUC(name='auc')}

from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
from keras import models
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import os
from sklearn import metrics
from tensorflow import keras

【问题讨论】:

  • 请在这个例子中添加导入,它们很重要
  • 是的,我添加了导入

标签: keras model auc


【解决方案1】:

只是不要编译模型:

model1 = models.load_model('baseline.h5', compile=False)
model1.compile(loss='binary_crossentropy', 
              optimizer='rmsprop', 
              metrics=[keras.metrics.AUC()])

【讨论】:

  • 通过设置compile=False,您将摆脱在compile 函数中定义的有关损失、优化器和其他内容的信息。如果你认为,答案是正确的,请标记它
猜你喜欢
  • 2018-11-07
  • 1970-01-01
  • 1970-01-01
  • 2021-11-30
  • 2020-08-26
  • 2020-06-04
  • 1970-01-01
  • 2019-08-03
  • 2019-03-26
相关资源
最近更新 更多