【Question Title】: Scikit-Learn GridSearchCV failing on a gensim LDA model
【Posted】: 2020-06-21 11:22:05
【Question】:

Here is the code that creates the model:

import gensim
NUM_TOPICS = 4
ldamodel = gensim.models.ldamodel.LdaModel(corpus, num_topics=NUM_TOPICS, id2word=dictionary, passes=100)
ldamodel.save('model5.gensim')
topics = ldamodel.print_topics(num_words=4)
print(topics)

Here is the GridSearchCV code:

from sklearn.model_selection import GridSearchCV

search_params = {'n_components': [4, 6, 8, 10, 20], 'learning_decay': [.5, .7, .9]}

# Init Grid Search Class
model = GridSearchCV(ldamodel, param_grid=search_params)

# Do the Grid Search
model.fit(data_vectorized)

Here is the output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-108-1a35c49ac19e> in <module>
      9 
     10 # Do the Grid Search
---> 11 model.fit(data_vectorized)
~\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\model_selection\_search.py in fit(self, X, y, groups, **fit_params)
    627 
    628         scorers, self.multimetric_ = _check_multimetric_scoring(
--> 629             self.estimator, scoring=self.scoring)
    630 
    631         if self.multimetric_:
~\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in _check_multimetric_scoring(estimator, scoring)
    471     if callable(scoring) or scoring is None or isinstance(scoring,
    472                                                           str):
--> 473         scorers = {"score": check_scoring(estimator, scoring=scoring)}
    474         return scorers, False
    475     else:
~\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in check_scoring(estimator, scoring, allow_none)
    399     if not hasattr(estimator, 'fit'):
    400         raise TypeError("estimator should be an estimator implementing "
--> 401                         "'fit' method, %r was passed" % estimator)
    402     if isinstance(scoring, str):
    403         return get_scorer(scoring)
TypeError: estimator should be an estimator implementing 'fit' method, <gensim.models.ldamodel.LdaModel object at 0x000002121E55D3C8> was passed

【Discussion】:

    Tags: python scikit-learn gensim lda gridsearchcv


    【Answer 1】:

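    You are trying to use the GridSearchCV object from the scikit-learn package, which requires the model object it runs to implement certain methods (specifically, as the error message says, the fit method). Since scikit-learn and gensim are unrelated, you need to make them compatible by subclassing an Estimator class in scikit-learn and wrapping the gensim training inside the fit method.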

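    Also, as far as I can tell, the LdaModel documentation does not use the parameters you are trying to search over (n_components and learning_decay). You can only search over parameter values that the model actually uses. (Those two names belong to scikit-learn's LatentDirichletAllocation; the closest gensim LdaModel parameters are num_topics and decay.)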
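If the goal is simply to tune LDA with the question's original `search_params`, one option is to swap gensim for scikit-learn's own `LatentDirichletAllocation`, which implements `fit` and `score` and uses exactly those parameter names. A sketch, with hypothetical toy `texts`; note that `learning_decay` only has an effect when `learning_method="online"`:

```python
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import GridSearchCV

# Hypothetical toy documents; repeated so every CV fold sees the whole vocabulary.
texts = [
    "cat dog pet food",
    "dog pet walk park",
    "stock market trade price",
    "market price bank money",
] * 3

# Document-term count matrix, the input format scikit-learn's LDA expects.
data_vectorized = CountVectorizer().fit_transform(texts)

search_params = {"n_components": [4, 6, 8, 10, 20], "learning_decay": [.5, .7, .9]}
# learning_decay is ignored by the default "batch" learning method.
lda = LatentDirichletAllocation(learning_method="online", random_state=0)
model = GridSearchCV(lda, param_grid=search_params, cv=3)
model.fit(data_vectorized)
print(model.best_params_)
```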

