【发布时间】:2021-05-07 20:03:41
【问题描述】:
我目前正在处理一个文本分类问题,需要我们将文本分类为四个标签之一。编码后的 y 值应该是 [0,1,2,3] 之一,应该是预测的标签。
但是,这个模型所做的预测似乎在 (0,1) 范围内,我有点困惑?此外,任何人都可以澄清这是 ANN 还是 RNN? TensorFlow 经验为零,但仍在苦苦挣扎……
model = Sequential()
model.add(Dense(16, activation='relu'))
model.add(Dense(4, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
from sklearn.preprocessing import LabelEncoder
#encode the label
label_encoder = LabelEncoder()
y_train=np.array(label_encoder.fit_transform(train_labels))
x_train=np.array(train_features)
y_true=np.array(label_encoder.fit_transform(dev_label))
#fit the model
model.fit(x_train,y_train,epochs=1)
y_pred=model.predict(dev_features)
和错误信息:Classification metrics can't handle a mix of multiclass and continuous-multioutput targets
【问题讨论】:
标签: python tensorflow