【发布时间】:2014-03-22 15:49:13
【问题描述】:
这个问题是针对 Python 库 scikit-learn 的。请让我知道将其发布在其他地方是否更好。谢谢!
现在的问题...
我有一个基于 BaseEstimator 的前馈神经网络类 ffnn,我使用 SGD 进行训练。它工作正常,我也可以使用 GridSearchCV() 并行训练它。
现在我想在函数 ffnn.fit() 中实现提前停止,但为此我还需要访问折叠的验证数据。一种方法是更改 sklearn.grid_search.fit_grid_point() 中的行
clf.fit(X_train, y_train, **fit_params)
变成类似的东西
clf.fit(X_train, y_train, X_test, y_test, **fit_params)
并更改 ffnn.fit() 以获取这些参数。这也会影响 sklearn 中的其他分类器,这是一个问题。我可以通过检查 fit_grid_point() 中的某种标志来避免这种情况,该标志告诉我何时以上述两种方式之一调用 clf.fit()。
在我不必编辑 sklearn 库中的任何代码的情况下,有人可以建议一种不同的方法吗?
或者,进一步将 X_train 和 y_train 随机拆分为训练/验证集并检查一个好的停止点,然后在所有 X_train 上重新训练模型是否正确?
谢谢!
【问题讨论】:
标签: python machine-learning neural-network scikit-learn cross-validation