【问题标题】:Error in counting the number of labels for a multi-label classification problem计算多标签分类问题的标签数量时出错
【发布时间】:2020-11-26 02:32:48
【问题描述】:

我正在尝试计算多标签分类问题的标签分布。请从 CSV 文件中查找包含的示例数据。

filenames   labels
tt3302594.jpg   ['deer']
tt2377194.jpg   ['deer']
tt2309762.jpg   ['dog', 'deer']
tt2870808.jpg   ['cat', 'deer']
tt2551396.jpg   ['cat', 'dog', 'deer']
tt4008652.jpg   ['dog']
tt2926810.jpg   ['deer']
tt3531604.jpg   ['dog', 'deer']
tt2290739.jpg   ['cat', 'deer']

我希望绘制一个 seaborn 图,其中 X 轴为单个标签,Y 轴为它们的计数值。

以下是代码:

import numpy as np
import pandas as pd
import seaborn as sns
from collections import Counter

train = pd.read_csv('example.csv')    # reading the csv file
meta = pd.DataFrame(train, columns=['filenames', 'labels'])
print(f'Found {len(meta)} images')
meta.sample(9)
all_labels = [label for lbs in meta['labels'] for label in lbs]
labels_count = Counter(all_labels)
ax = sns.countplot(all_labels, order=[k for k, _ in labels_count.most_common()], log=True)
ax.set_title('Number of images with a class label')
ax.set_ylim(1E2, 1E4)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90);

上面的代码,不是统计带有类标签的图片个数,而是统计标签中的每个字符,比如'''、'd'、'e'、'r'等。

【问题讨论】:

  • 可能标签值不是列表而是带括号的字符串,如果是这种情况使用 ast.literal)eval() 修复
  • @Ezer K 你能建议修改代码吗?我不清楚您的 cmets,因为我是 Python 编码的新手。谢谢。
  • 尝试添加:import ast; meta['labels'] = [ast.literal_eval(x) for x in meta['labels"].values]
  • 一个更一般的注释,尝试发布文本而不是图片(使用打印或 df.to_clipboard())
  • @Ezer K:谢谢。按照建议编辑帖子。但是,当我使用您的代码行时,它会引发如下值错误: raise ValueError('malformed node or string: ' + repr(node)) ValueError: malformed node or string: <_ast.name object at>跨度>

标签: python pandas dataframe computer-vision


【解决方案1】:

您需要使用literal_eval 将列表形成的字符串解析为真实列表(此外,对于发布的示例,y lims 会使条形消失,因此需要注释),这里是:

import numpy as np
import pandas as pd
import seaborn as sns
from collections import Counter
import ast

train = pd.read_csv('example.csv')    # reading the csv file
meta = pd.DataFrame(train, columns=['filenames', 'labels'])
print(f'Found {len(meta)} images')
meta.sample(9)
meta['labels'] = [ast.literal_eval(x) for x in meta['labels'].values] 
all_labels = [label for lbs in meta['labels'] for label in lbs]
labels_count = Counter(all_labels)
ax = sns.countplot(all_labels, order=[k for k, _ in labels_count.most_common()], log=True)
ax.set_title('Number of images with a class label')
# ax.set_ylim(1E2, 1E4)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90);

【讨论】:

    猜你喜欢
    • 2021-09-04
    • 1970-01-01
    • 2018-03-25
    • 2019-03-01
    • 1970-01-01
    • 2016-08-05
    • 1970-01-01
    • 2020-09-25
    • 2022-12-10
    相关资源
    最近更新 更多