【问题标题】:Using batched input with tf.math.invert_permutation将批处理输入与 tf.math.invert_permutation 一起使用
【发布时间】:2021-03-17 12:42:17
【问题描述】:

我有一个张量,其中包含一组整数 0time-1 的排列,例如有形状

[batch,time]

现在我想反转所有这些排列以获得相同形状的张量。 我知道这可以使用tf.math.invert_permutation 来完成[time] 形状的单个张量,但该函数不支持批量输入。如果输入张量多于一维,则会出错。

如何使tf.math.invert_permutation 与批处理输入一起工作?

【问题讨论】:

    标签: python tensorflow permutation


    【解决方案1】:

    使用tf.data.Dataset API 的一种方法。

    以下是我将如何实现它。

    【讨论】:

    • 感谢您的回答。除了使用 Dataset API,我还可以使用 tf.map_fn。其实好主意!我仍然有点担心性能,但它很容易实现并且让我现在就开始。我正在使用:tf.reshape(tf.map_fn(fn=lambda x: tf.math.invert_permutation(x), elems=tf.reshape(x, [-1, x.shape[len(x.shape)-1]])), x.shape)
    • 我实际上并不确定性能数据,但我想说使用输入数据管道的首选方法是根据 Tensorflow 使用 tf.data.Dataset API。
    • 我对此做了一些测试,就我而言,在这里使用map 非常慢。相反,我最终将tf.rangetf.scatter_nd 结合使用,并在此答案的末尾实现了简单的numpy 解决方案:stackoverflow.com/a/25535723/2766231
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-04-01
    • 2013-04-10
    • 2023-04-10
    • 2021-03-10
    • 2022-10-14
    • 1970-01-01
    相关资源
    最近更新 更多