Posted: 2021-04-05 07:00:18
Question:
I'm fairly new to TensorFlow, and I'm having trouble running this simple CNN. I've split my images into a separate directory for each class and load them into train_dataset with image_dataset_from_directory.
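According to the documentation, this should yield tuples (images, labels), where images has shape (batch_size, image_size[0], image_size[1], num_channels) and, with label_mode='categorical', labels is a float32 tensor of shape (batch_size, num_classes). num_channels is 3 because the images are RGB.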
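To make the two label layouts concrete, here is a small numpy sketch (not TensorFlow code) of what one batch of labels looks like under label_mode='int' versus label_mode='categorical'. The batch of random labels is made up, and 5 classes is an assumption taken from the [32,5] prediction shape in the error below.

```python
import numpy as np

# Hypothetical batch: 32 integer class indices over 5 classes
# (5 is assumed from the [32,5] predictions in the error message).
batch_size, num_classes = 32, 5
rng = np.random.default_rng(seed=1)
int_labels = rng.integers(0, num_classes, size=batch_size)

# label_mode='int'         -> shape (batch_size,): one integer class index per image
# label_mode='categorical' -> shape (batch_size, num_classes): float32 one-hot rows
one_hot_labels = np.eye(num_classes, dtype=np.float32)[int_labels]

print(int_labels.shape)      # (32,)
print(one_hot_labels.shape)  # (32, 5)

# Each one-hot row has a single 1.0, at the position of the integer label.
assert (one_hot_labels.argmax(axis=1) == int_labels).all()
```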
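However, when I try to fit my model, I get an error saying the predictions have shape [32,5] while the labels have shape [160]. It looks to me as though the batch dimension of the labels has been "collapsed".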
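One plausible reading of the numbers: 160 = 32 × 5, exactly the number of elements in a (32, 5) one-hot label batch. The failing node in the traceback (SparseSoftmaxCrossEntropyWithLogits) expects integer labels of shape (batch_size,), and, as far as I can tell, the TF 2.x Keras backend's sparse_categorical_crossentropy flattens the label tensor when its rank is not one less than the predictions' rank. The numpy sketch below only mimics that reshape under that assumption; it is not TensorFlow's implementation.

```python
import numpy as np

batch_size, num_classes = 32, 5
# Stand-in for the softmax output of the model: shape (32, 5).
preds = np.full((batch_size, num_classes), 1.0 / num_classes)
# One-hot labels as produced by label_mode='categorical': shape (32, 5).
one_hot = np.eye(num_classes)[np.arange(batch_size) % num_classes]

# A sparse loss expects labels with rank = prediction rank - 1 (i.e. shape (32,)).
needs_flatten = one_hot.ndim != preds.ndim - 1   # 2 != 1 -> True
# Mimicked reshape: labels flattened to 1-D, predictions keep their class axis.
labels_seen_by_loss = one_hot.reshape(-1) if needs_flatten else one_hot
preds_seen_by_loss = preds.reshape(-1, preds.shape[-1])

print(preds_seen_by_loss.shape, labels_seen_by_loss.shape)  # (32, 5) (160,)
```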
Here are some snippets:
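import tensorflow as tf

BATCH_SIZE = 32
EPOCHS = 1
IMG_SIZE = (300, 300)
SEED = 1
# class_names: list of the per-class directory names (defined elsewhere, not shown here)
train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory='train/train_images/',
    label_mode='categorical',  # one-hot labels, shape (batch_size, num_classes)
    class_names=class_names,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)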
IMG_SHAPE = IMG_SIZE + (3,)
n_classes = len(train_dataset.class_names)
def build_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(input_shape=IMG_SHAPE, kernel_size=(5, 5), filters=32, activation='relu'),
        tf.keras.layers.MaxPool2D(pool_size=(3, 3)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Dense(units=n_classes, activation='softmax')
    ])
    return model

model = build_model()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
model.fit(train_dataset, epochs = EPOCHS, batch_size = BATCH_SIZE)
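As a sanity check on the prediction side, the model's output shape can be traced by hand. This is pure-Python arithmetic assuming Keras defaults (padding='valid' for Conv2D and MaxPool2D, and MaxPool2D strides equal to pool_size when not given); the class count of 5 is an assumption inferred from the error, not stated in the question. The trace suggests [32, 5] is exactly what this model should output, so the mismatch would be on the label side.

```python
# Pure-Python shape trace for the model above (no TensorFlow needed).

def conv2d_valid(size, kernel):
    # Keras Conv2D default: padding='valid', strides=1 -> out = size - kernel + 1
    return size - kernel + 1

def maxpool_valid(size, pool):
    # Keras MaxPool2D default: strides=None (equal to pool_size), padding='valid'
    return (size - pool) // pool + 1

batch, h, w = 32, 300, 300                       # BATCH_SIZE and IMG_SIZE above
h, w = conv2d_valid(h, 5), conv2d_valid(w, 5)    # Conv2D(kernel_size=(5, 5), ...)
c = 32                                           # filters=32
h, w = maxpool_valid(h, 3), maxpool_valid(w, 3)  # MaxPool2D(pool_size=(3, 3))
flat = h * w * c                                 # Flatten()
n_classes = 5  # assumption: inferred from the [32,5] logits in the error

print((batch, h, w, c))    # (32, 98, 98, 32) after pooling
print((batch, flat))       # (32, 307328) after Flatten
print((batch, n_classes))  # (32, 5) from the final Dense(units=n_classes)
```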
Error message:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-19-86d96e744ef0> in <module>
----> 1 model.fit(train_dataset, epochs = EPOCHS, batch_size = BATCH_SIZE)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
109
110 # Running inside `run_distribute_coordinator` already.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1096 batch_size=batch_size):
1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
1099 if data_handler.should_sync:
1100 context.async_wait()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
805 # In this case we have created variables on the first call, so we run the
806 # defunned version which is guaranteed to never create variables.
--> 807 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
808 elif self._stateful_fn is not None:
809 # Release the lock early so that multiple threads can perform the call
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2827 with self._lock:
2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2830
2831 @property
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
1846 resource_variable_ops.BaseResourceVariable))],
1847 captured_inputs=self.captured_inputs,
-> 1848 cancellation_manager=cancellation_manager)
1849
1850 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1922 # No tape is watching; skip to running the function.
1923 return self._build_call_outputs(self._inference_function.call(
-> 1924 ctx, args, cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
1926 args,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
548 inputs=args,
549 attrs=attrs,
--> 550 ctx=ctx)
551 else:
552 outputs = execute.execute_with_cancellation(
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [32,5] and labels shape [160]
[[node sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at <ipython-input-18-1904262c6a7b>:1) ]] [Op:__inference_train_function_928]
Function call stack:
train_function
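The two Keras cross-entropy losses differ mainly in the label encoding they expect: categorical_crossentropy takes one-hot rows (what label_mode='categorical' yields), while sparse_categorical_crossentropy takes integer class indices (what label_mode='int' yields). A minimal numpy sketch of both formulas (not Keras' implementation, which also clips probabilities) shows they give the same per-example loss once the labels are encoded to match the loss:

```python
import numpy as np

def categorical_ce(one_hot, probs):
    # -sum_k y_k * log(p_k): labels are one-hot, shape (batch, classes)
    return -np.sum(one_hot * np.log(probs), axis=-1)

def sparse_categorical_ce(int_labels, probs):
    # -log(p[label]): labels are integer indices, shape (batch,)
    return -np.log(probs[np.arange(len(int_labels)), int_labels])

rng = np.random.default_rng(seed=0)
logits = rng.normal(size=(32, 5))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax
int_labels = rng.integers(0, 5, size=32)
one_hot = np.eye(5)[int_labels]

# Matching pairs agree: one-hot + categorical == integer + sparse.
print(np.allclose(categorical_ce(one_hot, probs),
                  sparse_categorical_ce(int_labels, probs)))  # True
```

In Keras terms this would point to either loss=tf.keras.losses.categorical_crossentropy with the current one-hot dataset, or keeping the sparse loss and loading the data with label_mode='int'.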
Comments:
-
The code in your traceback doesn't match the code you posted.
-
Ah, yes: I changed model.fit(train_dataset, epochs = EPOCHS, batch_size = BATCH_SIZE) to hist = model_cnn.fit(train_dataset, steps_per_epoch = 17117 // BATCH_SIZE, epochs = EPOCHS, batch_size = BATCH_SIZE) because I thought this might be related: stackoverflow.com/questions/63049638/… But I still get exactly the same error message.
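A quick arithmetic check on that change: steps_per_epoch only limits how many batches are drawn per epoch, and each batch keeps the same (images, labels) structure. Assuming 17117 is the number of training images (the comment implies but does not state this), the numbers work out as below; none of them affect the label shape, which would be consistent with the error staying the same.

```python
import math

num_images = 17117  # assumption: total training images, from the comment's expression
batch_size = 32     # BATCH_SIZE

steps_requested = num_images // batch_size                      # value passed as steps_per_epoch
batches_available = math.ceil(num_images / batch_size)          # batches in one full pass
last_batch = num_images - (batches_available - 1) * batch_size  # size of the final, partial batch

print(steps_requested, batches_available, last_batch)  # 534 535 29
```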
Tags: python tensorflow keras