【问题标题】:How do I determine the binary class predicted by a convolutional neural network on Keras?如何确定 Keras 上的卷积神经网络预测的二元类?
【发布时间】:2019-01-31 18:31:37
【问题描述】:

我正在构建一个 CNN 来对 Keras 进行情绪分析。 一切正常,模型已经过训练,可以投入生产了。

但是,当我尝试使用 model.predict() 方法对新的未标记数据进行预测时,它只会输出相关的概率。我尝试使用 np.argmax() 方法,但它总是输出 0,即使它应该是 1(在测试集上,我的模型达到了 80% 的准确率)。

这是我预处理数据的代码:

# Pre-processing data
x = df[df.Sentiment != 3].Headlines
y = df[df.Sentiment != 3].Sentiment

# Splitting training, validation, testing dataset
x_train, x_validation_and_test, y_train, y_validation_and_test = train_test_split(x, y, test_size=.3,
                                                                                      random_state=SEED)
x_validation, x_test, y_validation, y_test = train_test_split(x_validation_and_test, y_validation_and_test,
                                                                  test_size=.5, random_state=SEED)

tokenizer = Tokenizer(num_words=NUM_WORDS)
tokenizer.fit_on_texts(x_train)

sequences = tokenizer.texts_to_sequences(x_train)
x_train_seq = pad_sequences(sequences, maxlen=MAXLEN)

sequences_val = tokenizer.texts_to_sequences(x_validation)
x_val_seq = pad_sequences(sequences_val, maxlen=MAXLEN)

sequences_test = tokenizer.texts_to_sequences(x_test)
x_test_seq = pad_sequences(sequences_test, maxlen=MAXLEN)

这是我的模型:

MAXLEN = 25
NUM_WORDS = 5000
VECTOR_DIMENSION = 100

tweet_input = Input(shape=(MAXLEN,), dtype='int32')

tweet_encoder = Embedding(NUM_WORDS, VECTOR_DIMENSION, input_length=MAXLEN)(tweet_input)

# Combinating n-gram to optimize results
bigram_branch = Conv1D(filters=100, kernel_size=2, padding='valid', activation="relu", strides=1)(tweet_encoder)
bigram_branch = GlobalMaxPooling1D()(bigram_branch)
trigram_branch = Conv1D(filters=100, kernel_size=3, padding='valid', activation="relu", strides=1)(tweet_encoder)
trigram_branch = GlobalMaxPooling1D()(trigram_branch)
fourgram_branch = Conv1D(filters=100, kernel_size=4, padding='valid', activation="relu", strides=1)(tweet_encoder)
fourgram_branch = GlobalMaxPooling1D()(fourgram_branch)
merged = concatenate([bigram_branch, trigram_branch, fourgram_branch], axis=1)

merged = Dense(256, activation="relu")(merged)
merged = Dropout(0.25)(merged)
output = Dense(1, activation="sigmoid")(merged)

optimizer = optimizers.adam(0.01)

model = Model(inputs=[tweet_input], outputs=[output])
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=['accuracy'])
model.summary()

# Training the model
history = model.fit(x_train_seq, y_train, batch_size=32, epochs=5, validation_data=(x_val_seq, y_validation))

我还尝试将最终 Dense 层的激活次数从 1 更改为 2,但出现错误:

Error when checking target: expected dense_12 to have shape (2,) but got array with shape (1,)

【问题讨论】:

  • 欢迎来到 Stack Overflow!输出是单个激活,因此它似乎是单个二元类的概率。只需取一个操作点阈值(例如 0.5),如果概率相等或更大,则预测 true。事实上,这个网站上很可能还有一个对您有用的问题,但目前可能很难找到。

标签: python machine-learning keras deep-learning text-classification


【解决方案1】:

你正在做二进制分类。所以你有一个密集层,由一个单元组成,激活函数为sigmoid。 Sigmoid 函数在 [0,1] 范围内输出一个值,该值对应于给定样本属于正类(即第一类)的概率。低于 0.5 的所有内容都标记为 0(即负类),高于 0.5 的所有内容都标记为 1。因此,要找到预测的类,您可以执行以下操作:

preds = model.predict(data)
class_one = preds > 0.5

class_one 的真实元素对应于标记为 1 的样本(即正类)。

奖励:要找到预测的准确性,您可以轻松地将 class_one 与真实标签进行比较:

acc = np.mean(class_one == true_labels)

请注意,我假设true_labels 由零和一组成。


此外,如果您的模型是使用 Sequential 类定义的,那么您可以轻松使用predict_classes 方法:

pred_labels = model.predict_classes(data)

但是,由于您使用 Keras 函数式 API 来构建模型(在我看来,这样做是一件非常好的事情),因此您不能使用 predict_classes 方法,因为它对于此类模型的定义不明确.

【讨论】:

  • 感谢您的回答!现在更清楚了。但是我认为我有一个更大的问题。当我尝试使用模型预测未标记的数据时,即使数据明显是否定的,我也总是得到非常肯定的答案。我首先认为我的模型过度拟合了训练数据。因此,我尝试对测试中的文本进行分类。当我评估模型时,它被认为是高度消极的,但当我尝试预测它时,它是高度积极的。我使用与训练我的模型的标记器相同的标记器。
  • @RFTexas 你能否澄清一下“评估模型”和“尝试预测它”是什么意思?对于后者,我猜你使用predict 方法,但我不明白你在这里所说的“评估”是什么意思。
  • 首先我训练我的模型并在验证集上对其进行优化。然后我使用“评估”方法来查看我的模型在测试集上的表现。当我想出令人满意的准确性时,我想使用该模型来预测新数据。问题是,当测试集中像“IPO 失败后价格垂直下降”这样的句子时,它被我的模型标记为负数(显然),概率约为 0。但是当我尝试用我的模型,它说它是高度正的(大约 1)。
  • @RFTexas 我无法理解这些部分:“...它被 my model 标记为 negative...”和“ ...尝试用我的模型进行标记,它说它是高度积极的”。模型如何在给定相同数据的情况下同时预测负数和正数?
  • 这正是我想要弄清楚的!有点奇怪!一开始我以为是因为分词器。
猜你喜欢
  • 1970-01-01
  • 2019-08-18
  • 2019-05-22
  • 2016-09-03
  • 2015-04-15
  • 2018-02-06
  • 2017-08-06
  • 2018-03-19
  • 1970-01-01
相关资源
最近更新 更多