【问题标题】:Getting "TypeError: cannot pickle '_thread.RLock' object" when saving model with pickle使用泡菜保存模型时出现“TypeError:无法泡菜'_thread.RLock'对象”
【发布时间】:2020-11-17 13:20:22
【问题描述】:

我正在尝试将我的 keras 模型保存到一个 pickle 文件中,但是我收到了这个错误。有什么办法可以解决?或者保存和加载模型的更好方法是什么?我正在二进制预测 480x640 灰度图像。

按照我的代码:

def trainModel(data):
  batch_size = 3
  img_height = 480
  img_width = 640

  trainDataset = tf.keras.preprocessing.image_dataset_from_directory(
    data,
    validation_split=0.25,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    #class_names={"nao_doentes", "doentes"}
  )

  valDataset = tf.keras.preprocessing.image_dataset_from_directory(
    data,
    validation_split=0.25,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    #class_names={"nao_doentes", "doentes"}
  )

  AUTOTUNE = tf.data.experimental.AUTOTUNE

  trainDataset = trainDataset.cache().prefetch(buffer_size=AUTOTUNE)
  valDataset = valDataset.cache().prefetch(buffer_size=AUTOTUNE)

  num_classes = 2

  model = tf.keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)
  ])

  model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['accuracy']
  )

  model.fit(
    trainDataset,
    validation_data=valDataset,
    epochs=10
  )
  return model

model = trainModel(training_data)
with open('model.sav', 'wb') as f:
  pickle.dump(model, f)

with open('model.sav', 'rb') as f:
  model = pickle.load(f)

testing = np.ndarray(shape=(1, 1, 480, 640), dtype=np.float32)
image = load_img(os.path.join(test_data, "doentes/doente_6.jpg"), target_size=(480,640))
x = img_to_array(image)
x = np.expand_dims(x, axis=0)
testing = np.vstack([x])
print(model.predict(testing))

另外,当问题涉及图像分类情况时,您能否提供一些良好实践和解释的良好来源提供建议?我是该地区的新手,因此在搜索和链接不同来源的信息时我有些吃力。

【问题讨论】:

    标签: python keras model pickle


    【解决方案1】:

    通常,pickle 在为 pytorch、tensorflow 和 keras 保存 ml 模型权重时存在问题。要保存您的 keras 模型,请查看 their tutorials

    具体来说,尝试使用 keras 中的 save 和 load_module 函数:

    model.save('path/to/location')
    reconstructed_model = keras.models.load_model("path/to/location")
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2011-02-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-07-29
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多