假设您有两个元素形状分别为 (bs,d0,d1) 和 (bs,d0',d1) 的数据集,并且您想要一个元素形状为 (bs,d0+d0',d1) 的新数据集可以使用 tf.Dataset.zip 然后在第二个轴上连接每个元素,如下例所示:
import tensorflow as tf
a = tf.zeros((100,4,8))
b = tf.ones((100,1,8))
d1 = tf.data.Dataset.from_tensor_slices(a)
d1 = d1.batch(16,drop_remainder=True) # elements shape (16,4,8)
d2 = tf.data.Dataset.from_tensor_slices(b)
d2 = d2.batch(16,drop_remainder=True) # elements shape (16,1,8)
d = tf.data.Dataset.zip((d1,d2))
d = d.map(lambda x,y:tf.concat([x,y],axis=-2)) # elements shape (16,4+1,8)
it = iter(d)
x = next(it)
print(x.shape)
print(x)
如果您想将两个具有相同元素形状 (bs,d0,d1) 的数据集堆叠成一个具有元素形状 (bs,d0,d1,2) 的新数据集,您可以压缩这两个数据集,然后放样元素
import tensorflow as tf
a = tf.zeros((100,4,8))
b = tf.ones((100,4,8))
d1 = tf.data.Dataset.from_tensor_slices(a)
d1 = d1.batch(16,drop_remainder=True) # elements shape (16,4,8)
d2 = tf.data.Dataset.from_tensor_slices(b)
d2 = d2.batch(16,drop_remainder=True) # elements shape (16,4,8)
d = tf.data.Dataset.zip((d1,d2))
d = d.map(lambda x,y:tf.stack([x,y],axis=-1)) # elements shape (16,4,8,2)
it = iter(d)
x = next(it)
print(x.shape)
print(x)