【发布时间】:2020-06-18 08:39:46
【问题描述】:
是否有任何脚本/函数来拆分数据,计算每个图像中的类出现次数并平衡它们? 我以这种方式尝试过 sklearn train_test_split:
data = pd.read_csv('train_labels.csv')
data.head()
类是我想要预测的,在一张图像上我可以有 0..n 个矩形,每个矩形都有一个类。
data = data.drop_duplicates(subset="filename")
y = data['class']
X = data.drop('class',axis = 1)
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2)
但是,当我删除文件名中的重复项时,我会丢失信息,也许我会将文件发送到与许多其他类一起训练或测试,但如果我不删除它们,我可以在训练和测试中拥有一个文件。
感谢您的帮助。
【问题讨论】:
-
您可能正在寻找
train_test_split的stratify参数。我想电话是train_test_split(X, y, test=0.2, stratify=y)。 -
如果您希望平衡不平衡的类,您可以通过不同的策略(例如过采样)来实现。查看
imbalanced-learnPython 包上的文档。
标签: machine-learning scikit-learn computer-vision object-detection train-test-split