【问题标题】:How to pass Enum to ParameterGrid?如何将枚举传递给 ParameterGrid?
【发布时间】:2019-07-23 03:06:53
【问题描述】:

我想将几个参数传递给一个应该通过超参数和multiprocessing优化的函数。

如果 Enum 作为参数传递,则此操作失败。请参阅下面的代码。

在这种情况下我如何传递Enum

from sklearn.model_selection import ParameterGrid
from multiprocessing import Pool
from enum import Enum


class MyStrategy(Enum):
    var1 = 1
    var2 = 2


var1 = MyStrategy(1)
var2 = MyStrategy(2)
abc = [1, 2]
xyz = [3, 4]
if True:
    pg = [{'variant': var1,
           'abc': abc,
           'xyz': xyz, },
          {'variant': var2,
           'abc': abc, }]
else:
    pg = [{'variant': '1',
           'abc': abc,
           'xyz': xyz, },
          {'variant': '2',
           'abc': abc, }]
parameterGrid = ParameterGrid(pg)

def myFunc(myParam):
    print(myParam)

pool = Pool(1)
myList = pool.map(myFunc, parameterGrid)

如果False:

确实有效
{'abc': 1, 'variant': '1', 'xyz': 3}
{'abc': 1, 'variant': '1', 'xyz': 4}
{'abc': 2, 'variant': '1', 'xyz': 3}
{'abc': 2, 'variant': '1', 'xyz': 4}
{'abc': 1, 'variant': '2'}
{'abc': 2, 'variant': '2'}

并以True 失败:

TypeError: object of type 'MyStrategy' has no len()

【问题讨论】:

  • 为什么将 var1 和 var2 定义为单独的枚举?如果你只做 pg = [{'variant': MyStrategy.var1,...{'variant': MyStrategy.var2,... 似乎可以工作
  • 感谢您的评论。不幸的是,我得到了错误TypeError: Parameter grid value is not iterable (key='variant', value=<MyStrateg.var1: 1>)
  • 哦,我的错,我的意思是 pg = [{'variant': [MyStrategy.var1],..{'variant': [MyStrategy.var2],...
  • 不幸的是,这也会导致错误:AttributeError: Can't get attribute 'myFunc' on <module '__main__' (built-in) >

标签: python scikit-learn python-multiprocessing hyperparameters


【解决方案1】:

首先,您需要将枚举中的值更改为字符串,因为这与您在False 块中用于variant 键的数据类型相同。像这样

from sklearn.model_selection import ParameterGrid
from multiprocessing import Pool
from enum import Enum


class MyStrategy(Enum):
    var1 = '1'  #<--------Notice the value is string not int
    var2 = '2'

接下来,使用如下值:

var1 = MyStrategy.var1.value
var2 = MyStrategy.var2.value

【讨论】:

    【解决方案2】:

    True 示例中,在pg 字典列表中,'variant' 的两个 都是&lt;enum 'MyStrategy'&gt; 类型。而在False 示例中输入str

    要使用枚举器复制 False 变体,您必须更改以下内容:

    class MyStrategy(Enum):
        var1 = '1'
        var2 = '2'
    
    pg = [{'variant': MyStrategy('1').value,  # Or MyStrategy.var1.value
               'abc': abc,
               'xyz': xyz, },
          {'variant': MyStrategy('2').value,  # Or MyStrategy.var2.value
               'abc': abc, }]
    
    

    【讨论】:

    • 您好,感谢您的回答。不幸的是,这会导致错误:AttributeError: Can't get attribute 'myFunc' on &lt;module '__main__' (built-in) &gt;
    • 嗨@user7468395,我的解决方案使用Enum 复制了False pg 块的行为。 MyStrategy('1').value 都返回一个字符串,所以如果您的 False pg 有效,我的解决方案也应该有效...但是,您似乎正在修改 pg 中的 parameterGrid,我邀请您修改您的问题,包括实现代码parameterGrid。您还可以通过以下方式访问Enum 对象:MyStrategy.__members__,查看文档here。它只不过是一个只读 dict!!
    • 嗨@jesteras,不,除了问题中可见的内容之外,我什么都不做......所以,我正在做的事情的完整代码在我的问题中可见。但是,关于您在上一条评论中提到的MyStrategy.__members__,您的整体解决方案如何?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2023-02-06
    • 2011-08-15
    • 1970-01-01
    • 2011-08-03
    • 1970-01-01
    • 2010-09-05
    • 2012-10-29
    相关资源
    最近更新 更多