【Question title】: How to get training & validation loss of Keras scikit-learn wrapper in cross validation?
【Posted】: 2020-07-07 01:48:26
【Question】:

I know that model.fit in Keras returns a callbacks.History object, from which we can get the loss and other metrics, as shown below.

...
train_history = model.fit(X_train, Y_train,
                    batch_size=batch_size, epochs=nb_epoch,
                    verbose=1, validation_data=(X_test, Y_test))
loss = train_history.history['loss']
val_loss = train_history.history['val_loss']

However, in my new experiment I am using cross validation with a Keras model wrapped in KerasClassifier (full example code: https://chrisalbon.com/deep_learning/keras/k-fold_cross-validating_neural_networks/):

from keras.wrappers.scikit_learn import KerasClassifier

# Wrap Keras model so it can be used by scikit-learn
neural_network = KerasClassifier(build_fn=create_network,
                                 epochs=10,
                                 batch_size=100,
                                 verbose=1)

Since I am now using cross validation, I am not sure how to get the training and validation loss.

【Question discussion】:

    Tags: python machine-learning keras scikit-learn cross-validation


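    【Solution 1】:

    As clearly mentioned in the documentation, the scoring argument of cross_val_score is

    Similar to cross_validate but only a single metric is permitted.

    Hence it cannot be used to return all the loss and metric information of Keras model.fit(); in any case, it returns only one final score per fold, not the per-epoch history.

    The Keras scikit-learn wrapper is meant as a convenience, for when you are not really interested in all the underlying details (such as training & validation loss and accuracy). If that is not the case, you should revert to using Keras directly. Here is how, using the example you linked to and elements from this answer of mine: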
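To make the limitation concrete, here is a minimal sketch using scikit-learn's `cross_validate`, which does accept several metrics. A plain `LogisticRegression` stands in for the Keras wrapper (an assumption, chosen only so the snippet runs without Keras). Even with multiple metrics and `return_train_score=True`, each entry holds one value per fold, not a per-epoch curve:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_validate

# Small synthetic dataset, similar in spirit to the make_classification call below
X, y = make_classification(n_samples=300, n_features=10, n_informative=3,
                           n_redundant=0, random_state=0)

# LogisticRegression is a stand-in for the Keras wrapper (assumption)
results = cross_validate(LogisticRegression(max_iter=1000), X, y,
                         cv=KFold(n_splits=3, shuffle=True, random_state=0),
                         scoring=['accuracy', 'neg_log_loss'],
                         return_train_score=True)

# Every array has one entry per fold (3 here), with no per-epoch history
for key in sorted(results):
    print(key, results[key].shape)
```

Swapping in the Keras wrapper would not change this shape: `cross_validate` always reports one score per fold and metric.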

    import numpy as np
    from keras import models, layers
    from sklearn.datasets import make_classification
    from sklearn.model_selection import KFold
    
    np.random.seed(0)
    
    # Number of features
    number_of_features = 100
    
    # Generate features matrix and target vector
    features, target = make_classification(n_samples = 10000,
                                           n_features = number_of_features,
                                           n_informative = 3,
                                           n_redundant = 0,
                                           n_classes = 2,
                                           weights = [.5, .5],
                                           random_state = 0)
    
    def create_network():
        network = models.Sequential()
        network.add(layers.Dense(units=16, activation='relu', input_shape=(number_of_features,)))
        network.add(layers.Dense(units=16, activation='relu'))
        network.add(layers.Dense(units=1, activation='sigmoid'))
    
        network.compile(loss='binary_crossentropy', 
                        optimizer='rmsprop', 
                        metrics=['accuracy']) 
    
        return network
    
    n_splits = 3
    kf = KFold(n_splits=n_splits, shuffle=True)
    
    loss = []
    acc = []
    val_loss = []
    val_acc = []
    
    # cross validate:
    for train_index, val_index in kf.split(features):
        model = create_network()
        hist = model.fit(features[train_index], target[train_index],
                         epochs=10,
                         batch_size=100,
                         validation_data = (features[val_index], target[val_index]),
                         verbose=0)
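        # Keras >= 2.3 and tf.keras record metrics=['accuracy'] under 'accuracy'/'val_accuracy';
        # older Keras versions used 'acc'/'val_acc' instead
        loss.append(hist.history['loss'])
        acc.append(hist.history['accuracy'])
        val_loss.append(hist.history['val_loss'])
        val_acc.append(hist.history['val_accuracy'])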
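Because the history key names depend on the Keras version ('acc' vs 'accuracy'), a variant of the loop could simply append each fold's `hist.history` dict to a list and merge them afterwards, without hard-coding any key. A minimal sketch; `collect_histories` is a hypothetical helper, and the fold values below are made-up placeholders shaped like `hist.history`, not real training output:

```python
from collections import defaultdict

def collect_histories(fold_histories):
    """Merge per-fold History.history dicts into {metric: [one list per fold]}.

    fold_histories: list of dicts such as hist.history, one per fold.
    """
    merged = defaultdict(list)
    for history in fold_histories:
        for metric, values in history.items():
            merged[metric].append(list(values))
    return dict(merged)

# Placeholder histories for 2 folds x 2 epochs (hypothetical numbers)
fold_histories = [
    {'loss': [0.70, 0.60], 'val_loss': [0.72, 0.65]},
    {'loss': [0.71, 0.58], 'val_loss': [0.69, 0.63]},
]
merged = collect_histories(fold_histories)
print(merged['loss'])  # [[0.7, 0.6], [0.71, 0.58]]
```

Inside the real loop, this would mean calling `fold_histories.append(hist.history)` after each `model.fit(...)` and `collect_histories(fold_histories)` once the loop ends.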
    

    After the cross-validation loop has run, loss, for example, will be:

    [[0.7251979386058971,
      0.6640552306833333,
      0.6190941931069023,
      0.5602273066015956,
      0.48771809028534785,
      0.40796665995284814,
      0.33154681897220617,
      0.2698465999525444,
      0.227492357244586,
      0.1998490962115201],
     [0.7109123742507104,
      0.674812126485093,
      0.6452083222258479,
      0.6074533335751673,
      0.5627432800365635,
      0.51291748379345,
      0.45645068427406726,
      0.3928780094229408,
      0.3282097149542538,
      0.26993170230619656],
     [0.7191790426458682,
      0.6618405645963258,
      0.6253172250296091,
      0.5855853647883192,
      0.5438901918195831,
      0.4999895181964501,
      0.4495182811042725,
      0.3896359298090465,
      0.3210068798340545,
      0.25932698793518183]]
    

    i.e. a list of n_splits lists (here 3), each containing the training loss for every epoch (here 10). The same holds for the other lists...
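Since every fold trains for the same number of epochs (no early stopping in the loop above), these lists can be stacked into an (n_splits, epochs) array and summarized per epoch, e.g. to plot a single averaged learning curve. A minimal NumPy sketch; `summarize_folds` is a hypothetical helper, and the numbers are the first two epochs of the three folds shown above, rounded to three decimals:

```python
import numpy as np

def summarize_folds(per_fold):
    """Mean and standard deviation across folds for each epoch.

    per_fold: list of n_splits lists, each holding one value per epoch.
    """
    arr = np.asarray(per_fold, dtype=float)  # shape (n_splits, n_epochs)
    return arr.mean(axis=0), arr.std(axis=0)

# First two epochs of the three folds listed above, rounded
loss_head = [[0.725, 0.664], [0.711, 0.675], [0.719, 0.662]]
mean_loss, std_loss = summarize_folds(loss_head)
print(mean_loss.shape)  # (2,)
```

Applied to the full `loss` and `val_loss` lists, this gives one mean value (and spread) per epoch across the folds.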

    【Discussion】:
