【问题标题】:How to set class weights in DecisionTreeClassifier for multi-class setting如何在 DecisionTreeClassifier 中设置类权重以进行多类设置
【发布时间】:2020-10-16 05:52:08
【问题描述】:

我正在使用sklearn.tree.DecisionTreeClassifier 来训练三类分类问题。

3个类的记录数如下:

A: 122038
B: 43626
C: 6678

当我训练分类器模型时,它无法学习类 - C。虽然效率是 65-70%,但它完全忽略了 C 类。

后来我知道了class_weight参数,但我不知道如何在多类设置中使用它。

这是我的代码:(我使用了balanced,但准确度更差)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=1,class_weight='balanced')
clf = clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)

如何使用与类分布成比例的权重。

其次,有没有更好的方法来解决这个不平衡类问题以提高准确性?

【问题讨论】:

    标签: python machine-learning scikit-learn decision-tree


    【解决方案1】:

    您还可以将值字典传递给 class_weight 参数,以设置您自己的权重。例如,将 A 级的重量减半:

    class_weight={
        'A': 0.5,
        'B': 1.0,
        'C': 1.0
    }
    

    通过执行 class_weight='balanced',它会自动设置与类频率成反比的权重。

    更多信息可以在 class_weight 参数下的文档中找到: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

    通常可以预期平衡类会降低准确性。这就是为什么准确性通常被认为是不平衡数据集的一个较差指标的原因。

    您可以尝试 sklearn 包含的 Balanced Accuracy 指标,但还有许多其他潜在指标可以尝试,这取决于您的最终目标。

    https://scikit-learn.org/stable/modules/model_evaluation.html

    如果您不熟悉“混淆矩阵”及其相关值(例如精度和召回率),那么我会从那里开始您的研究。

    https://en.wikipedia.org/wiki/Precision_and_recall

    https://en.wikipedia.org/wiki/Confusion_matrix

    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    【讨论】:

      【解决方案2】:

      “平衡”模式是开始的方式。

      “平衡”模式使用 y 的值来自动调整 权重与输入数据中的类频率成反比 作为 n_samples / (n_classes * np.bincount(y))


      要手动定义权重,您需要字典字典列表,具体取决于问题。


      class_weight dict,dict 列表或“平衡”,默认=None

      与 {class_label: weight} 形式的类关联的权重。如果没有,所有的类都应该有一个权重。对于多输出 问题,可以按照与 y 列。

      请注意,对于多输出(包括多标签),应为 > 在其自己的字典中的每一列的每个类定义权重。例如,对于四类多标签 > 分类权重应该是 [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1 : 1}] 而不是 [{1:1}, {2:5}, {3:1}, {4:1}]。


      例子:

      如果A类的频率是10%,B类的频率是90%:

      clf = tree.DecisionTreeClassifier(class_weight={A:9,B:1})
      

      【讨论】:

        猜你喜欢
        • 2018-08-25
        • 2013-04-10
        • 1970-01-01
        • 2018-01-30
        • 2019-06-03
        • 2021-08-21
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多