【发布时间】:2019-10-31 09:11:17
【问题描述】:
我研究了 tensorflow 中不同的切片方式,即tf.gather 和tf.gather_nd。
在 tf.gather 中,它只是对一个维度进行切片,在 tf.gather_nd 中,它只接受一个 indices 应用于输入张量。
我需要的是不同的,我想使用两个不同的张量对输入张量进行切片;一个切片在行上,第二个切片在列上,它们的形状不一定相同。
例如:
假设这是我想要提取其中一部分的输入张量。
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
第二个是:
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
第三张量:
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
现在,我想使用rows_tf 和columns_tf 对input_tf 进行切片。在行中索引[1 2 5],在columns_tf 中索引[1]。同样,[1 2 5] 和 [2] 在 columns_tf 中的行。
或者,[1 4 6] 和 [2]。
总的来说,rows_tf 中的每个索引,与columns_tf 中的相同索引都会提取input_tf 的一部分。
因此,预期的输出将是:
[[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
例如,这里的第一行 [8.3356, 0., 8.457685 ] 正在使用
rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)
有几个关于 tensorflow 切片的问题,尽管他们使用了 tf.gather 或 tf.gather_nd 和 tf.stack,但没有给出我想要的输出。
无需提及,在numpy 中,我们可以通过调用:input_tf[rows_tf, columns_tf] 轻松做到这一点。
我也看过这个高级索引,它试图模拟 numpy 中可用的高级索引,但它仍然不像 numpy 灵活https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb
这是我尝试过的不正确的:
tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)
这段代码的维度输出是(8, 1, 3, 8),完全不正确。
提前致谢!
【问题讨论】:
-
您应该编辑您的问题,以便正确格式化所有常量(添加
,) -
@DSC 你是对的,我现在就做,谢谢
-
为什么您对您提到的“收集”操作的输出不满意?听起来它可以工作。是因为它返回它变平吗?如果是这样,您可以知道“rows_tf”和“columns_tf”的尺寸来重塑它
-
在stackoverflow.com/questions/56640222/… 中添加
scatter_idx后,您应该可以使用tf.gather_nd(params=input_tf, indices=scatter_idx)后跟tf.reshape来获得所需的形状。 -
我复制了下面的完整代码作为答案,您可以使用其他线程中任何其他答案的类似方式来获取
sparse_indices。
标签: python tensorflow keras slice