【问题标题】:Tensorflow Dataset Mask Sequence for Evaluation用于评估的 TensorFlow 数据集掩码序列
【发布时间】:2026-01-07 23:50:02
【问题描述】:

问题:

给定可变长度输入,准确度指标不正确,因为短向量被填充到更长的向量。

使用来自kerasMasking 层解决了这个问题,通过对所有零值应用掩码,但由于我的序列自然包含零,而且超长(50,000 个标记),这会使训练速度减慢 50 倍!

详情

我有一个以tf.data.Dataset 表示的数据集,其中每个示例包含 3 个属性:

  1. src - 输入序列
  2. tgt - 类 ID 输出序列
  3. tokens 序列中的令牌数

我想训练一个序列标记模型。

为了在 Keras 中使用它,我知道我的数据集中只需要一个 xy,所以我映射:

dataset = dataset.map(lambda d: (d['src'], d['tgt']))

然后我传递给一个 keras 模型:

model = tf.keras.Sequential([
  tf.keras.layers.LSTM(hidden_size, return_sequences=True),
  tf.keras.layers.Dense(2),
  tf.keras.layers.Activation(activations.softmax)
])

有没有办法在图形模式下应用掩码作为数据集管道的一部分? (掩码为tf.sequence_mask(datum['tokens'])

替代解决方案

或者,如果我也可以在我的数据集中传递tokens 的数量,并创建我自己的评估指标来应用掩码,那么传递一个未屏蔽的序列就没有问题。

我找不到如何传递包含 3 个项目而不是 2 个项目的数据集,keras 顺序模型似乎不允许这样做。

【问题讨论】:

  • 那为什么不使用顺序模型以外的东西呢?
  • @NicolasGervais 请建议在这种情况下使用什么,然后,我不确定你的意思
  • 函数式api,子类化api...
  • 谢谢,我已经尝试了一个小时,没有任何进展。您能否以答案为例进行详细说明?
  • 如果可以帮助您,我可以将具有多个输入的功能齐全的模型粘贴到玩具数据集上。然后,您可以根据自己的喜好对其进行调整

标签: python tensorflow keras


【解决方案1】:

希望这会有所帮助。您可以看到它不仅需要输入/输出,还需要第二个输入。我还使用超过 1 个输出,因此您可以将信息传入和传出神经网络。我希望您会对这个示例的可定制性感到振奋。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn.datasets import load_iris
from functools import partial
tf.keras.backend.set_floatx('float64')
iris, target = load_iris(return_X_y=True)

X = iris[:, :3]
y = iris[:, 3]
z = target

onehot = partial(tf.one_hot, depth=3)

dataset = tf.data.Dataset.from_tensor_slices((X, y, z)).shuffle(150)

train_ds = dataset.take(120).shuffle(10).\
    batch(8).map(lambda a, b, c: (a, b, onehot(c)))

test_ds = dataset.skip(120).take(30).shuffle(10).\
    batch(8).map(lambda a, b, c: (a, b, onehot(c)))

next(iter(train_ds))

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d0 = Dense(64, activation='relu')
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(1)
        self.d3 = Dense(3)

    def call(self, x, training=None, **kwargs):
        x = self.d0(x)
        x = self.d1(x)
        a = self.d2(x)
        b = self.d3(x)
        return a, b

model = MyModel()

loss_obj_reg = tf.keras.losses.MeanAbsoluteError()
loss_obj_cat = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

loss_reg_train = tf.keras.metrics.Mean(name='regression loss')
loss_cat_train = tf.keras.metrics.Mean(name='categorical loss')

loss_reg_test = tf.keras.metrics.Mean(name='regression loss')
loss_cat_test = tf.keras.metrics.Mean(name='categorical loss')

train_acc = tf.keras.metrics.CategoricalAccuracy()
test_acc = tf.keras.metrics.CategoricalAccuracy()


@tf.function
def train_step(inputs, y_reg, y_cat):
    with tf.GradientTape() as tape:
        pred_reg, pred_cat = model(inputs, training=True)
        reg_loss = loss_obj_reg(y_reg, pred_reg)
        cat_loss = loss_obj_cat(y_cat, pred_cat)

    gradients = tape.gradient([reg_loss, cat_loss], model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    loss_reg_train(reg_loss)
    loss_cat_train(cat_loss)

    train_acc(y_cat, pred_cat)


@tf.function
def test_step(inputs, y_reg, y_cat):
    pred_reg, pred_cat = model(inputs, training=False)
    reg_loss = loss_obj_reg(y_reg, pred_reg)
    cat_loss = loss_obj_cat(y_cat, pred_cat)

    loss_reg_test(reg_loss)
    loss_cat_test(cat_loss)

    test_acc(y_cat, pred_cat)


for epoch in range(250):

    loss_reg_train.reset_states()
    loss_cat_train.reset_states()

    loss_reg_test.reset_states()
    loss_cat_test.reset_states()

    train_acc.reset_states()
    test_acc.reset_states()

    for xx, yy, zz in train_ds:
        train_step(xx, yy, zz)

    for xx, yy, zz in test_ds:
        test_step(xx, yy, zz)

    template = 'Epoch {:3} ' \
               'MAE {:5.3f} TMAE {:5.3f} ' \
              'Entr {:5.3f} TEntr {:5.3f} ' \
               'Acc {:7.2%} TAcc {:7.2%}'

    print(template.format(epoch+1,
                        loss_reg_train.result(),
                        loss_reg_test.result(),
                        loss_cat_train.result(),
                        loss_cat_test.result(),
                        train_acc.result(),
                        test_acc.result()))

如果您需要任何信息,请告诉我。

【讨论】: