【发布时间】:2021-05-31 20:21:46
【问题描述】:
我有一个名为 tensor 的 rank-3 张量,形状为 [batch_size, axis_1, axis_2],并希望将其沿第一个轴拆分为 batch_size 切片,如下所示:
batch_size = tf.shape(tensor)[0]
batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
很遗憾,这不起作用,因为batch_size 的值在构建图的过程中是未知的。
我该如何解决这个问题?
我收到此错误:
TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
奇怪的是,尝试在其他 TensorFlow 函数中使用 batch_size 似乎有效:
tensor = tf.reshape(tensor, [batch_size, -1])
尽管batch_size 的值在图形构建过程中是未知的,但工作正常。
问题是tf.split() 的问题吗?
【问题讨论】:
标签: python tensorflow