【问题标题】:How does SGD Incremental Learning behave in scikit-learn?SGD 增量学习在 scikit-learn 中的表现如何?
【发布时间】:2020-01-11 16:37:29
【问题描述】:

我正在学习 Scikit-learn 中的增量学习算法。 sci-kit learn 中的 SGD 就是这样一种算法,它允许通过传递块/批次来增量学习。

  • sci-kit learn 是否将所有用于训练数据的批次保存在内存中?
  • 或者它是否将内存中的块/批次保持到一定大小?
  • 还是在内存中训练时只保留一个块/批次,并在训练后移除其他经过训练的块/批次?这是否意味着它会遭受灾难性遗忘?

【问题讨论】:

    标签: python machine-learning scikit-learn


    【解决方案1】:

    增量学习的目的是将整个训练数据保存在内存中。因此,可以在整体上不适合内存的大数据集上进行学习。如果训练数据逐个可用,增量学习也很有用。

    随机梯度下降 (SGD) 不会在内存中保留任何批次,除了它正在处理的批次。但是,这并不意味着它会立即忘记过去的补丁。批次用于计算梯度,用于更新模型系数。因此,尽管数据本身被丢弃,但批次中包含的信息仍保留在模型中。

    由于梯度是使用最新批次更新的,因此新批次对模型当前训练状态的影响比旧批次更大。你可以说最近的批次在模型的记忆中更加生动,而它逐渐忘记了旧的批次。

    这里有一个玩具例子来说明这个问题(代码在底部):

    一个 SGD 分类器在前 100 个批次中使用三个类进行增量训练。在训练数据中不存在批次 100-200 第 3 类。很明显,分类器“忘记”了它之前学到的关于这个类的一切。您可以将这种效果标记为“灾难性遗忘”,或者您可能将其视为可取的“适应数据变化”;解释取决于用例。

    所以,是的,SGD 确实似乎受到catastrophic forgetting 的影响。不过,我认为这没什么大不了的。只是在特定应用程序中设计训练策略时必须注意的事项。

    import numpy as np
    from sklearn.linear_model import SGDClassifier
    from sklearn.datasets import make_blobs
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt
    
    np.random.seed(42)
    n_features = 150
    centers = np.concatenate([np.eye(3)*3, np.zeros((3, n_features-3))], axis=1)
    
    x_test, y_test = make_blobs([100, 100, 100], centers=centers)
    
    cla = SGDClassifier()
    performance = []
    
    def train_some_batches(n_samples_per_class):
        for _ in range(100):
            x_batch, y_batch = make_blobs(n_samples_per_class, centers=centers)
            cla.partial_fit(x_batch, y_batch, classes=[0, 1, 2])
            conf = confusion_matrix(y_test, cla.predict(x_test))
            performance.append(np.diag(conf) / np.sum(conf, axis=1))
    
    train_some_batches([50, 50, 50])
    train_some_batches([50, 50, 0])            
    
    plt.plot(performance)
    plt.legend(['class 1', 'class 2', 'class 3'])
    plt.xlabel('training batches')
    plt.ylabel('accuracy')
    
    plt.show()
    

    【讨论】:

    • 非常感谢!我想知道“partial_fit”和常规“fit”之间是否有任何区别。
    • @Farzana 好吧,常规的fit 是终极遗忘体验。它使用您提供的任何数据从头开始完全训练分类器。
    • 很抱歉再次问你同样的问题。我只是想确保我清楚地理解它。所以,我的理解是'fit'完全忘记了它在分类后学到的东西,但'partial_fit'逐渐忘记了。我说的对吗?
    • @Farzana 差不多。 partial_fit 更新分类器,但 fit “覆盖”它。如果您多次调用“fit”,则只有最后一个很重要。
    • > 只有最后一个重要...除非warm_start 为真
    猜你喜欢
    • 1970-01-01
    • 2016-06-18
    • 2016-03-12
    • 2020-04-02
    • 2020-11-15
    • 2015-08-26
    • 2019-02-06
    • 2017-07-16
    • 2017-11-26
    相关资源
    最近更新 更多