【问题标题】:Error calling adapt in TextVectorization Keras在 TextVectorization Keras 中调用适应时出错
【发布时间】:2025-12-04 23:20:03
【问题描述】:

我有以下代码,带有自定义标准化定义。

def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    regex = tf.strings.regex_replace(lowercase, r'[^\w]', ' ')
    regex = tf.strings.regex_replace(regex, ' +', ' ')

    return tf.strings.split(regex)

vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    standardize=custom_standardization,
    max_tokens=50000,
    output_mode="int",
    output_sequence_length=100,
)

但是当我像这样打电话给adapt时,我得到了下一个错误

vectorize_layer.adapt(['the cat'])
# Error:
InvalidArgumentError: Expected 'tf.Tensor(False, shape=(), dtype=bool)' to be true. Summarized data: b'the given axis (axis = 2) is not squeezable!'

根据他们的解释,

当使用自定义可调用对象进行拆分时,可调用对象接收到的数据将被挤出第一个维度 - 而不是 [["string to split"], ["another string to split"]],Callable 将看到[“要拆分的字符串”,“要拆分的另一个字符串”]。可调用对象应返回第一个维度包含拆分标记的张量 - 在此示例中,我们应该看到类似 [["string", "to", "split"], ["another", "string", "to “, “分裂”]]。这使得可调用站点与 tf.strings.split() 原生兼容。

Blockquote Source

但我看不到错误在哪里

编辑:我在我的代码中做了一些研究 当我传递像['The other day was raining', 'Please call me later'] 这样的数组时,函数custom_standardization() 返回类似这样的内容

[['the', 'other', 'day', 'was', 'raining'], ['pleasse', 'call', 'me', 'later']]

所以看起来不尊重有相同的形状。为什么它会改变想法?

【问题讨论】:

    标签: python tensorflow keras nlp


    【解决方案1】:

    我参考了您之前分享的document。以下提到了自定义标准化

    当使用自定义可调用进行标准化时,接收到的数据 可调用对象将与传递给该层的完全相同。可调用的 应该返回一个与输入形状相同的张量。

    所以我将return tf.strings.split(regex) 替换为return regex(因为拆分正在改变此处的形状)。请像这样尝试。

    import tensorflow as tf
    
    def custom_standardization(input_data):
        lowercase = tf.strings.lower(input_data)
        regex = tf.strings.regex_replace(lowercase, r'[^\w]', ' ')
        regex = tf.strings.regex_replace(regex, ' +', ' ')
    
        return regex
    
    vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
        standardize=custom_standardization,
        max_tokens=50000,
        output_mode="int",
        output_sequence_length=100,
    )
    
    #checking input shape and output shape are shape or not 
    input = tf.constant([["foo !  @ qux  #bar"], ["qux baz"]])
    print(input)
    print(custom_standardization(input))
    
    vectorize_layer.adapt(["foo qux bar"])
    

    提供gist供参考。

    【讨论】: