【发布时间】:2020-09-17 09:19:18
【问题描述】:
我想在 TensorFlow 2.x 中的一个模型中为输入张量的切片分配一些值(我正在使用 2.2,但准备接受 2.1 的解决方案)。 我正在尝试做的一个非工作模板是:
import tensorflow as tf
from tensorflow.keras.models import Model
class AddToEven(Model):
def call(self, inputs):
outputs = inputs
outputs[:, ::2] += inputs[:, ::2]
return outputs
当然,在构建这个 (AddToEven().build(tf.TensorShape([None, None]))) 时,我收到以下错误:
TypeError: 'Tensor' object does not support item assignment
我可以通过以下方式实现这个简单的示例:
class AddToEvenScatter(Model):
def call(self, inputs):
batch_size = tf.shape(inputs)[0]
n = tf.shape(inputs)[-1]
update_indices = tf.range(0, n, delta=2)[:, None]
scatter_nd_perm = [1, 0]
inputs_reshaped = tf.transpose(inputs, scatter_nd_perm)
outputs = tf.tensor_scatter_nd_add(
inputs_reshaped,
indices=update_indices,
updates=inputs_reshaped[::2],
)
outputs = tf.transpose(outputs, scatter_nd_perm)
return outputs
(您可以通过以下方式进行完整性检查:
model = AddToEvenScatter()
model.build(tf.TensorShape([None, None]))
model(tf.ones([1, 10]))
)
但是正如您所见,编写起来非常复杂。这仅适用于一维(+批量)张量的静态更新次数(此处为 1)。
我想要做的是更多的参与,我认为用tensor_scatter_nd_add 编写它将会是一场噩梦。
当前很多关于该主题的 QA 都涵盖了变量但不包括张量的情况(参见例如 this 或 this)。 提到here 确实 pytorch 支持这一点,所以我很惊讶地看到最近没有任何 tf 成员对此主题的回应。 This answer 并没有真正帮助我,因为我需要生成某种蒙版,这也会很糟糕。
因此,问题是:如何在没有tensor_scatter_nd_add 的情况下有效地进行切片分配(计算方面、内存方面和代码方面)?诀窍是我希望它尽可能动态,这意味着inputs 的形状可以是可变的。
(对于任何好奇的人,我正在尝试将this code 翻译成 tf)。
此问题最初发布于in a GitHub issue。
【问题讨论】:
-
由于没有更好的解决方案,我使用
tensor_scatter_nd_update为此创建了一个模块。从长远来看,希望我不必诉诸于此。但与此同时,如果有人想使用它,您可以查看here。
标签: python tensorflow keras tensorflow2.x