【问题标题】:Beam Search Decoder Tensorflow 2.0束搜索解码器 Tensorflow 2.0
【发布时间】:2021-06-30 10:42:02
【问题描述】:

我希望在 Tensorflow 2.0 alpha 中使用注意力和束搜索来实现对神经网络进行排序的序列。虽然他们网站上的教程非常有用,但由于 contrib 库已弃用,我无法找出实现光束搜索的最佳方法 - 谁能指出我正确的方向?

我尝试使用 TF2.0s 升级脚本将我的 tensorflow 1.X 束搜索升级到 2.0,但它不支持 contrib 库。

这就是光束搜索代码查找 1.x 的方式

decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=self.embeddings,
                    start_tokens=tf.fill([self.batch_size], tf.constant(2)),
                    end_token=tf.constant(3),
                    initial_state=initial_state,
                    beam_width=self.beam_width,
                    output_layer=self.projection_layer
                )
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder, output_time_major=True, maximum_iterations=summary_max_len, scope=decoder_scope)
self.prediction = tf.transpose(outputs.predicted_ids, perm=[1, 2, 0])

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    Tensorflow 1.x 中很少有 API 被移至 Tensorflow 2.x 中的不同 API。 Tf.contrib 就是这样一个库,它部分迁移到了 Tensorflow 插件。

    对于tf.contrib.seq2seq.BeamSearchDecoder 移动到tfa.seq2seq.BeamSearchDecoder in TFv2.x.

    tfa.seq2seq.BeamSearchDecoder(
        cell: tf.keras.layers.Layer,
        beam_width: int,
        embedding_fn: Optional[Callable] = None,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        length_penalty_weight: tfa.types.FloatTensorLike = 0.0,
        coverage_penalty_weight: tfa.types.FloatTensorLike = 0.0,
        reorder_tensor_arrays: bool = True,
        **kwargs
    )
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-05-18
      • 2020-09-04
      相关资源
      最近更新 更多