【问题标题】:Tensorflow Error: "Label IDs must < n_classes", but my Label IDs appear to meet this requirement alreadyTensorflow 错误:“标签 ID 必须 < n_classes”,但我的标签 ID 似乎已经满足此要求
【发布时间】:2018-08-07 17:46:56
【问题描述】:

我正在尝试创建一个 Python 3 程序来使用 Tensorflow 将句子分类。但是,当我尝试运行我的代码时,我遇到了一系列非常冗长的错误。以下错误似乎是我的问题的基础:

InvalidArgumentError: 断言失败:[标签 ID 必须 1...] [y (linear/head/assert_range/Const:0) = ] 2

我正在使用 Scikit-Learn 的 LabelEncoder() 方法创建标签 ID,应该满足这个要求;他们的documentation page 说,“使用0n_classes-1 之间的值编码标签。

我要运行的代码是:

import tensorflow as tf
import pandas as pd
from sklearn import preprocessing
from sklearn.model_selection import train_test_split


data_df = pd.read_csv('data.csv') #data.csv has 2 columns: "Category", and "Description"

features = data_df.drop('Category', axis=1) #drop Category column
lab_enc = preprocessing.LabelEncoder()
labels = lab_enc.fit_transform(data_df['Category']) #Encode labels with value between 0 and n_classes-1
labels = pd.Series(labels) #pandas_input_func needs the labels in Series format    

features_train, features_test, labels_train, labels_test = train_test_split(features, labels, test_size=0.3, random_state=101)


description = tf.feature_column.categorical_column_with_hash_bucket('Description', hash_bucket_size=1000)
feat_cols = [description]

input_func = tf.estimator.inputs.pandas_input_fn(x=features_train, y=labels_train, batch_size=100, num_epochs=None, shuffle=True)

model = tf.estimator.LinearClassifier(feature_columns=feat_cols)
model.train(input_fn=input_func, steps=1000)

我使用的 data.csv 文件是一小组乱码测试数据:

我不知道如何继续。我只发现一篇帖子引用了类似的问题here,但如果我理解正确的话,该用户的问题似乎与我的问题不同。

非常感谢任何见解!

【问题讨论】:

  • 我认为您需要与类别相同数量的标签。不知道如何设置。

标签: python pandas tensorflow machine-learning scikit-learn


【解决方案1】:

试试这个:

# Explicitly specify the number of classes, e.g. 10
model = tf.estimator.LinearClassifier(feature_columns=feat_cols, n_classes=10)

默认值n_classes=2,内部表示tensorflow使用sigmoid交叉熵损失。设置类数将使其成为softmax交叉熵。

【讨论】:

  • 非常感谢!原来是这样。我想我没想到分类器需要以这种方式明确定义许多类。你能告诉我使 n_classes 比 n_classes 实际上有什么好处吗?我做了 n_classes=len(lab_enc.classes_) 因为这会给我唯一标签的确切数量。
  • 这是因为 tensorflow 在读取数据之前正在构建图。它做的第一件事就是选择分类器——二分类或多分类。我认为这里真正的问题是令人困惑的错误消息。
  • 啊,我明白了。我忘记了 tensorflow 在构建计算图时使用占位符等操作。非常感谢您的帮助!
猜你喜欢
  • 2018-01-30
  • 2021-04-09
  • 1970-01-01
  • 2021-10-14
  • 2016-04-03
  • 2013-01-17
  • 2010-11-03
  • 2016-10-14
  • 1970-01-01
相关资源
最近更新 更多