【发布时间】:2021-04-28 03:01:21
【问题描述】:
版本:张量流 2.3.0、numpy 1.18.5、python 3.8.2
我想使用我的 TensorFlow 模型中的第一层删除输入张量的一些选定切片。例如,我有一个(180, 90, 25) 的输入形状(其中 180 是批量大小),我想从最后一个维度中删除索引列表indices = [3, 4, 5, 6, 7, 22, 23, 24],以便在输入张量上调用此层之后,我会得到一个形状为(180, 90, 25 - len(indices)) 的张量,其中每个选定的(180, 90) 形状的张量切片都已通过索引最后一个维度来删除。
目前,我正在使用这一层:
class RemoveSelectedIndices(tf.keras.layers.Layer):
def __init__(self, indices=[3,4,5,6,7,22,23,24]):
super(RemoveSelectedIndices, self).__init__(name="RemoveSelectedIndices")
self.indices = self.add_weight(name="indices", shape=len(indices), dtype=tf.int32, trainable=False,
initializer=lambda *args, **kwargs: indices)
def build(self, input_shape):
pass
def call(self, input_tensor):
X = tf.unstack(input_tensor, num=input_tensor.shape[-1], axis=2) # list of 25 (180, 90)-shaped slices
indices = sorted(list(self.indices.value().numpy()))
for i in reversed(indices):
del X[i]
X = tf.stack(X, axis=2) # restacking the list back together
return X
当我测试它时(通过创建一个 numpy 数组并使用 tf.convert_to_tensor 然后在张量上调用该层),这工作得非常好,但是当我尝试使用该层作为第一层来构建模型时,我收到一个错误:
import tensorflow as tf
from tensorflow.keras.layers import Input
inputs = Input(shape=(90, 25))
X = RemoveSelectedIndices()(inputs)
# gives me AttributeError: 'Tensor' object has no attribute 'numpy'
# references the line indices = sorted(list(self.indices.value().numpy()))
为什么会发生这种情况,有什么办法可以解决吗?
(注意:我知道我可以对数据本身执行此操作,但数据集很大,除非必须,否则我宁愿不要过多地处理数据集。)
提前谢谢你!
【问题讨论】:
标签: python tensorflow machine-learning keras deep-learning