【问题标题】:How to pass fix hyperparameters as variables for Keras-Tuner?如何将修复超参数作为 Keras-Tuner 的变量传递?
【发布时间】:2025-12-06 15:00:01
【问题描述】:

我想使用 Keras Tuner 对 Keras 模型进行超参数调优。

import tensorflow as tf
from tensorflow import keras
import keras_tuner as kt

def model_builder(hp):

  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3)

tuner.search(train_X, train_y, epochs=50)

到目前为止,一切都很好。但是,我还想定义一些模型参数(如输入图像尺寸)作为model_builder的输入参数,我一无所知,该怎么做:

def model_builder(hp, img_dim1, img_dim2):

  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(img_dim1, img_dim2)))
...

tuner = kt.Hyperband(model_builder(img_dim1, img_dim2),
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3)

貌似不行。如何将img_dim1, img_dim2 提供给hp 以外的模型?

【问题讨论】:

    标签: python machine-learning keras keras-tuner


    【解决方案1】:

    一个简单的解决方案是在 python 中使用“部分函数”,如下所示:

    from functools import partial
    
    #...
    
    model_builder_ready = partial(model_builder, img_dim1 = value1, img_dim2 = value2)
    
    tuner = kt.Hyperband(model_builder_ready,
                         objective='val_accuracy',
                         max_epochs=10,
                         factor=3)
    

    【讨论】:

    • 奇怪:tuner = RandomSearch(model_builder_ready,objective=Objective("val_f1_m", direction="max"), max_trials=2, executions_per_trial=2) 给出错误:TypeError: model_builder() got multiple values对于参数“img_dim1”
    • @Fredrik 是的,这很奇怪!你需要更仔细地调试你的代码。