【问题标题】:TF2.0 Data API get n_i samples from each class labelTF2.0 Data API 从每个类标签中获取 n_i 个样本
【发布时间】:2019-12-23 02:57:36
【问题描述】:

我必须使用 TF2 Keras 模型将形状为 32x32 的输入分为 3 类。我的训练集有 7000 个示例

>>> X_train.shape # (7000, 32, 32)
>>> Y_train.shape # (7000, 3)

每个类的示例数量各不相同(例如,class_0 有 ~2500 个示例,而 class_1 有 ~800 个等)

我想使用 tf.data API 创建一个数据集对象,该对象返回批量训练数据,没有。来自[n_0, n_1, n_2] 指定的每个类的示例。

我想从每个类别中随机抽取这些n_i 样本,并从X_train, Y_train 替换

例如,如果我调用 get_batch([100, 150, 125]),它应该从 class_0 的 X_batch 返回 100 个随机样本,从 class_1 返回 150 个随机样本,从 class_2 返回 125 个随机样本。

如何使用 TF2.0 数据 API 实现这一点,以便可以使用它来训练 Keras 模型?

【问题讨论】:

    标签: python tensorflow keras tensorflow-datasets tensorflow2.0


    【解决方案1】:

    Keras 的train_test_split 实际上有一个参数。虽然它不允许您选择确切数量的样本,但它会从类中均匀地选择它们。

    X_train_stratified, X_test_stratified, y_train_strat, y_test_strat = train_test_split(X_train, y_train, test_size=0.2, stratify=y)
    

    如果你想做交叉验证,你也可以使用stratified shuffle split

    希望我能正确理解你的问题

    【讨论】:

      【解决方案2】:

      一种可能的方法是如下进行:

      1. X_trainY_train 中的数据加载到单个tf.data 数据集中,以便我们确保每个X 都与正确的Y 匹配
      2. .shuffle() 然后使用 filter() 将数据集拆分为每个 n_i
      3. 编写我们的get_batch 函数以从每个数据集中返回正确数量的样本,shuffle() 样本然后将其拆分回XY

      类似这样的:

      # 1: Load the data into a Dataset
      raw_data = tf.data.Dataset.zip(
          (
              tf.data.Dataset.from_tensor_slices(X_train),
              tf.data.Dataset.from_tensor_slices(Y_train)
          )
        ).shuffle(7000)
      
      
      # 2: Split for each category
      def get_filter_fn(n):
        def filter_fn(x, y):
          return tf.equal(1.0, y[n])
        return filter_fn
      
      n_0s = raw_data.filter(get_filter_fn(0))
      n_1s = raw_data.filter(get_filter_fn(1))
      n_2s = raw_data.filter(get_filter_fn(2))
      
      # 3:
      def get_batch(n_0,n_1,n_2):
        sample = n_0s.take(n_0).concatenate(n_1s.take(n_1)).concatenate(n_2s.take(n_2))
        shuffled = sample.shuffle(n_0 + n_1 + n_2)
        return shuffled.map(lambda x,y: x),shuffled.map(lambda x,y: y) 
      
      

      所以现在我们可以这样做了:

      x_batch, y_batch = get_batch(100, 150, 125)
      

      请注意,我在这里使用了一些可能浪费的操作,追求一种我认为直观且直接的方法(特别是针对过滤操作阅读 raw_data 数据集 3 次),因此我没有声称这是最有效的方法完成你所需要的,但对于像你描述的那样适合内存的数据集,我相信这种低效率将可以忽略不计

      【讨论】:

        猜你喜欢
        • 2018-07-03
        • 2021-06-01
        • 1970-01-01
        • 1970-01-01
        • 2022-11-18
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多