【发布时间】:2019-12-23 02:57:36
【问题描述】:
我必须使用 TF2 Keras 模型将形状为 32x32 的输入分为 3 类。我的训练集有 7000 个示例
>>> X_train.shape # (7000, 32, 32)
>>> Y_train.shape # (7000, 3)
每个类的示例数量各不相同(例如,class_0 有 ~2500 个示例,而 class_1 有 ~800 个等)
我想使用 tf.data API 创建一个数据集对象,该对象返回批量训练数据,没有。来自[n_0, n_1, n_2] 指定的每个类的示例。
我想从每个类别中随机抽取这些n_i 样本,并从X_train, Y_train 替换
例如,如果我调用 get_batch([100, 150, 125]),它应该从 class_0 的 X_batch 返回 100 个随机样本,从 class_1 返回 150 个随机样本,从 class_2 返回 125 个随机样本。
如何使用 TF2.0 数据 API 实现这一点,以便可以使用它来训练 Keras 模型?
【问题讨论】:
标签: python tensorflow keras tensorflow-datasets tensorflow2.0