【发布时间】:2021-02-13 17:22:02
【问题描述】:
我已经从 CIFAR10 加载了我的训练集和验证集,如下所示:
train = tfds.load('cifar10', split='train[:90%]', shuffle_files=True)
validation = tfds.load('cifar10', split='train[-10%:]', shuffle_files=True)
我已经为我的 CNN 创建了架构
model = ...
现在我正在尝试使用 model.fit() 来训练我的模型,但我不知道如何从我的对象中分离出“图像”和“标签”。训练和验证如下所示:
print(train) # same layout as the validation set
<_OptionsDataset shapes: {id: (), image: (32, 32, 3), label: ()}, types: {id: tf.string, image: tf.uint8, label: tf.int64}>
我的幼稚方法是这样,但那些 OptionsDatasets 不能下标。
history = model.fit(train['image'], train['label'], epochs=100, batch_size=64, validation_data=(validation['image'], test['label'], verbose=0)
【问题讨论】:
标签: python tensorflow machine-learning conv-neural-network