【问题标题】:How do you to save and load a tensorflow model that has a feature_layer?如何保存和加载具有 feature_layer 的 tensorflow 模型?
【发布时间】:2020-01-14 14:18:51
【问题描述】:

我正在使用我自己的数据集来学习本教程:https://www.tensorflow.org/tutorials/load_data/csv。 这是我对模型的顺序:

model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=[1]),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
    ])

现在,我想保存经过训练的模型,以便在其他程序中使用该模型。

我使用的第一种方法是使用检查点回调。

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1, period=10)


def train():
    model.fit(train_data, epochs=50)#, callbacks=[cp_callback])
    test_loss, test_acc = model.evaluate([[0],[0.4],[0.6],[1],[4.56],[2.1]],  [[0],[0],[1],[1],[1],[0]])
    print('\nTest accuracy:', test_acc)

但是,这不起作用,因为在我想使用模型的其他程序中,我无法复制模型的确切形状:

def create_model():
  model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=[1]),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
    ])

我也试过这个:

numeric_column = tf.feature_column.numeric_column('numeric', shape=[1])
numeric_columns = [numeric_column]

preprocessing_layer = tf.keras.layers.DenseFeatures(numeric_columns)

def create_model():
  model = keras.Sequential([
    preprocessing_layer,
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
    ])

...但它仍然没有工作。

在那之后,我尝试使用model.save("myModel.h5"),但这根本不会保存,现在我迷路了。

请帮忙。谢谢你。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    试试Saver类成员函数:saver.savesaver.restore。 请参阅此tutorial 以获得明确的说明。

    【讨论】:

    • 可能不会,因为 Saver 在 Tensorflow 中保存了一个会话。对于 Keras,我们不必显式管理会话。 Keras 提供了保存和加载模型本身的功能,例如参见here
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-04-12
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多