【问题标题】:Split and Recombine Tensorflow Dataset拆分和重组 TensorFlow 数据集
【发布时间】:2020-12-29 18:23:42
【问题描述】:

我目前有一个带有多个批次的 tensorflow Dataset(批次数是可变的,但可以被 4 整除)。我想取出每 4 批用作测试,其余的用作训练,但我还没有遇到一个优雅的解决方案。所需结果的简化视觉示例:

Dataset = [b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12]
train = [b1,b2,b3,b5,b6,b7,b9,b10,b11]
test = [b4,b8,b12]

关于Datasets 的训练验证测试拆分的大多数解决方案都使用Dataset.take()Dataset.skip() 的组合,因为他们不介意将数据拆分到中间的某个位置。但是,如果我要使用此解决方案,则需要我计算数据集的大小,使用多个 take()s 和 skip()s 在其上运行一个丑陋的循环,然后收集结果并将它们连接在一起形成一个新的Dataset。有没有更好的方法来选择 tensorflow 数据集中的批次间隔?

【问题讨论】:

    标签: python tensorflow keras tensorflow-datasets


    【解决方案1】:

    解决方案可以通过enumerate()filter()map()的组合来实现,类似于here提供的答案。

    玩具示例:

    list(
        Dataset.from_tensor_slices(np.arange(12))
        .batch(2)
        .as_numpy_iterator()
    )
    

    输出:

    [array([0, 1]),
     array([2, 3]),
     array([4, 5]),
     array([6, 7]),
     array([8, 9]),
     array([10, 11])]
    

    玩具示例的解决方案:

    list(
        Dataset.from_tensor_slices(np.arange(12))
        .batch(2)
        #solution starts here
        .enumerate() 
        .filter(lambda i, data: (i+1)%4 !=0)
        .map(lambda i,data: data)
        #solution ends here
        .as_numpy_iterator()
    )
    

    出来:

    [array([0, 1]), 
     array([2, 3]), 
     array([4, 5]),
     array([8, 9]),
     array([10, 11])]
    

    【讨论】:

      猜你喜欢
      • 2018-12-10
      • 1970-01-01
      • 2022-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-06-25
      • 2017-11-14
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多