【发布时间】:2017-02-12 21:08:46
【问题描述】:
scikit-learn 的train_test_split() 函数中的stratify 参数有问题。这是一个虚拟示例,在我的数据中随机出现相同的问题:
from sklearn.model_selection import train_test_split
a = [1, 0, 0, 0, 0, 0, 0, 1]
train_test_split(a, stratify=a, random_state=42)
返回:
[[1, 0, 0, 0, 0, 1], [0, 0]]
它不应该在测试子集中也选择一个“1”吗?从我期望train_test_split() 和stratify 的工作方式来看,它应该返回如下内容:
[[1, 0, 0, 0, 0, 0], [0, 1]]
random_state 的某些值会发生这种情况,而其他值则可以正常工作;但是每次我必须分析数据时,我都无法搜索它的“正确”值。
我有 python 2.7 和 scikit-learn 0.18。
【问题讨论】:
-
如果您尝试使用
stratify=np.unique(a)会怎样? -
很遗憾,它不起作用,因为传递给
stratify的列表必须与要拆分的列表长度相同。 -
文档中没有任何地方声明即使在很小的子集中也会有所有类。如果您将唯一的 1 添加到列表中,那么您将在测试拆分中获得 1 类。我认为它应该与您的火车拆分中的第 1 类部分相同。例如,如果您删除“分层”,那么您将得到列表的尾部,而不是带有混洗类的列表。
标签: python python-2.7 scikit-learn