【发布时间】:2021-03-10 13:59:08
【问题描述】:
我正在尝试设置从 tf.py_function 调用返回的 numpy 数组的形状。函数返回一个数组元组及其形状,所以我可以在tf.function 中设置它的形状。但是我不能在 set_shape 调用中使用返回的形状数组。我尝试将其转换为 int32 或通过索引等访问它,但所有结果都会导致 TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'strided_slice:0' shape=<unknown> dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>' 错误。
@tf.function
def samples(img, mask) -> tuple:
xs, xs_shape = tf.py_function(process,[128, 128], [tf.float32, tf.int32])
xs.set_shape(tf.TensorShape(xs_shape))
return xs
ds = ds.map(samples)
【问题讨论】:
标签: python tensorflow keras tensorflow2.0