【发布时间】:2019-07-23 03:06:53
【问题描述】:
我想将几个参数传递给一个应该通过超参数和multiprocessing优化的函数。
如果 Enum 作为参数传递,则此操作失败。请参阅下面的代码。
在这种情况下我如何传递Enum?
from sklearn.model_selection import ParameterGrid
from multiprocessing import Pool
from enum import Enum
class MyStrategy(Enum):
var1 = 1
var2 = 2
var1 = MyStrategy(1)
var2 = MyStrategy(2)
abc = [1, 2]
xyz = [3, 4]
if True:
pg = [{'variant': var1,
'abc': abc,
'xyz': xyz, },
{'variant': var2,
'abc': abc, }]
else:
pg = [{'variant': '1',
'abc': abc,
'xyz': xyz, },
{'variant': '2',
'abc': abc, }]
parameterGrid = ParameterGrid(pg)
def myFunc(myParam):
print(myParam)
pool = Pool(1)
myList = pool.map(myFunc, parameterGrid)
如果False:
{'abc': 1, 'variant': '1', 'xyz': 3}
{'abc': 1, 'variant': '1', 'xyz': 4}
{'abc': 2, 'variant': '1', 'xyz': 3}
{'abc': 2, 'variant': '1', 'xyz': 4}
{'abc': 1, 'variant': '2'}
{'abc': 2, 'variant': '2'}
并以True 失败:
TypeError: object of type 'MyStrategy' has no len()
【问题讨论】:
-
为什么将 var1 和 var2 定义为单独的枚举?如果你只做
pg = [{'variant': MyStrategy.var1,...和{'variant': MyStrategy.var2,...似乎可以工作 -
感谢您的评论。不幸的是,我得到了错误
TypeError: Parameter grid value is not iterable (key='variant', value=<MyStrateg.var1: 1>) -
哦,我的错,我的意思是
pg = [{'variant': [MyStrategy.var1],..和{'variant': [MyStrategy.var2],... -
不幸的是,这也会导致错误:
AttributeError: Can't get attribute 'myFunc' on <module '__main__' (built-in) >
标签: python scikit-learn python-multiprocessing hyperparameters