【问题标题】:How to split a tensorflow dataset into N datasets with shuffling如何通过改组将张量流数据集拆分为 N 个数据集
【发布时间】:2025-11-25 14:15:01
【问题描述】:

我有一个 tensorflow 数据集ds,我想将其拆分为 N 个数据集,它们的并集是原始数据集,并且它们之间不共享样本。 我试过了:

ds_list = [ds.shard(N,index=i) for i in range(N)]

但不幸的是,这不是随机的:每个新数据集总是会从原始数据集中获得相同的样本。例如,ds_list[0] 的样本数为 0,N,2N,3N...,而 ds_list[1] 的样本数为 1,N+1,2N+1,3N+1... 有没有办法将原始数据集随机细分为相同大小的数据集?

不幸的是,简单地在之前洗牌并不能解决问题:

import tensorflow as tf
import math

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])

N=2
ds = ds.shuffle(20)
ds_list = [ds.shard(N,index=i) for i in range(N)]


for ds in ds_list:
    shard_set = sorted(set(list(ds.as_numpy_iterator())))
    print(shard_set)

输出:

    [3, 5, 6, 8, 11, 12, 14, 15, 19, 20]
    [1, 2, 4, 5, 6, 7, 8, 14, 15, 20]

同:

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])
N=2
ds_list = []
ds = ds.shuffle(20)
size = ds.__len__()
sub = math.floor(size/N)
for n in range(N):
    ds_sub = ds.take(sub)
    remainder = ds.skip(sub)
    ds_list.append(ds_sub)
    ds = remainder  

for ds in ds_list:
    shard_set = sorted(set(list(ds.as_numpy_iterator())))
    print(shard_set)

【问题讨论】:

    标签: tensorflow tensorflow2.0


    【解决方案1】:

    也许(对于 N 个分片):

    ds_list = []
    ds = ds.shuffle()
    size = ds.__len__()
    sub = floor(size/N)
    for n in range(N):
        ds_sub = ds.take(sub)
        remainder = ds.skip(sub)
        ds_list.append(ds_sub)
        ds = remainder  
    

    【讨论】:

    • 感谢您的回复。您提出的解决方案不起作用,因为它不能保证所有样本都在一个且只有一个“拆分”数据集中
    • 我相信它确实如此,因为在循环结束时 ds 被设置为余数。所以添加到 ds_sub 的点不会再被考虑。
    • 我猜这是因为洗牌,因为采取和跳过操作可能涉及不同顺序的样本。随意使用我在主要问题中提供的代码。
    • 之前没有发生随机播放。然后,您继续获取并跳过,直到分配了所有数据。
    • 你测试过你的代码吗?只需尝试我在主要问题中提供的代码。你会看到你会得到共享一些样本的分片
    【解决方案2】:

    你可以先打乱数据集,然后分片:

    ds = ds.shuffle(buffer_size)
    ds_list = [ds.shard(N,index=i) for i in range(N)]
    

    这里buffer_size是TF用于排序的缓冲区大小。如果数据集的大小很小,您可以将示例总数传递为buffer_size。否则,可以放入内存的较小数字(例如 100)将起作用。

    【讨论】:

    • 感谢您的回复。您提出的解决方案不起作用,因为它不能保证所有样本都在一个且只有一个“拆分”数据集中
    • 我认为确实如此,因为 shuffle 不会多次考虑相同的示例(因为您没有重复数据集)。你能提供一个例子吗?如果您在 shuffle 之前使用 repeat ,那么您所说的是有道理的,但由于这里没有发生这种情况,我认为这应该可行。
    • 嗨,这里有一个可视化测试的可能代码。有些样本在两个数据集中,有些则不在。我编辑了主要问题