【发布时间】:2025-11-25 14:15:01
【问题描述】:
我有一个 tensorflow 数据集ds,我想将其拆分为 N 个数据集,它们的并集是原始数据集,并且它们之间不共享样本。
我试过了:
ds_list = [ds.shard(N,index=i) for i in range(N)]
但不幸的是,这不是随机的:每个新数据集总是会从原始数据集中获得相同的样本。例如,ds_list[0] 的样本数为 0,N,2N,3N...,而 ds_list[1] 的样本数为 1,N+1,2N+1,3N+1... 有没有办法将原始数据集随机细分为相同大小的数据集?
不幸的是,简单地在之前洗牌并不能解决问题:
import tensorflow as tf
import math
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])
N=2
ds = ds.shuffle(20)
ds_list = [ds.shard(N,index=i) for i in range(N)]
for ds in ds_list:
shard_set = sorted(set(list(ds.as_numpy_iterator())))
print(shard_set)
输出:
[3, 5, 6, 8, 11, 12, 14, 15, 19, 20]
[1, 2, 4, 5, 6, 7, 8, 14, 15, 20]
同:
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])
N=2
ds_list = []
ds = ds.shuffle(20)
size = ds.__len__()
sub = math.floor(size/N)
for n in range(N):
ds_sub = ds.take(sub)
remainder = ds.skip(sub)
ds_list.append(ds_sub)
ds = remainder
for ds in ds_list:
shard_set = sorted(set(list(ds.as_numpy_iterator())))
print(shard_set)
【问题讨论】: