【Question title】:Error when checking input: expected dense_203_input to have shape (1202,) but got array with shape (1,)
【Posted】:2020-06-30 19:11:54
【Question】:

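I built a very simple neural network intended for reinforcement learning. However, I can't get any predictions out of it, because calling predict raises an error.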

The error in question:

Error when checking input: expected dense_203_input to have shape (1202,) but got array with shape (1,)

The model in question:

    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense

    def _build_compile_model(self):  # method of the agent class
        model = Sequential()
        model.add(Dense(300, activation='relu', input_dim=1202))
        model.add(Dense(300, activation='relu'))
        model.add(Dense(200, activation='relu'))
        model.add(Dense(self._action_size, activation='softmax'))

        model.compile(loss='mse', optimizer=self._optimizer)
        return model
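To make the shape contract of this architecture concrete, here is a NumPy-only sketch of the same forward pass, assuming each Dense layer computes relu(x @ W + b) with a softmax on the last layer. The random weights, the helper names, and action_size = 4 are hypothetical stand-ins, and this is not the Keras implementation. It expects a batch of shape (batch, 1202); an array of shape (1202, 1) fails at the very first matrix product.

```python
import numpy as np

rng = np.random.default_rng(0)
action_size = 4  # hypothetical; the original model uses self._action_size

# Layer sizes copied from _build_compile_model: 1202 -> 300 -> 300 -> 200 -> action_size
sizes = [1202, 300, 300, 200, action_size]
weights = [(rng.normal(scale=0.05, size=(m, n)), np.zeros(n))
           for m, n in zip(sizes[:-1], sizes[1:])]

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtract the row maximum for numerical stability
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def forward(x):
    """Mimic model.predict on a batch x of shape (batch, 1202)."""
    h = x
    for W, b in weights[:-1]:
        h = relu(h @ W + b)
    W, b = weights[-1]
    return softmax(h @ W + b)

q = forward(np.ones((1, 1202)))
print(q.shape)  # one row of action probabilities per sample
```

Calling forward(np.ones((1202, 1))) raises a ValueError from the matrix product, the NumPy analogue of the Keras shape error.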

The error occurs when calling model.predict(state), where state is an array of shape (1202, 1).

The full error message is:

ValueError                                Traceback (most recent call last)
<ipython-input-148-06b7a01facef> in <module>
     18     new_state, reward = env.step(action, new_demand_a, new_demand_b) # Take action, get new state and reward
     19     new_state = np.reshape(new_state, [1202, -1])
---> 20     agent.update(old_state, new_state, action, reward) # Let the agent update internal
     21     average_reward.append(reward) # Keep score
     22     if i % 100 == 0 and i != 0: # Print out metadata every 100th iteration

<ipython-input-145-142ae54ce43f> in update(self, old_state, new_state, action, reward)
     49     def update(self, old_state, new_state, action, reward):
     50         print(old_state.shape)
---> 51         target = self.q_network.predict(old_state)
     52         t = self.target_network.predict(new_state)
     53         target[0][action] = reward + self.gamma * np.amax(t)

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1011         max_queue_size=max_queue_size,
   1012         workers=workers,
-> 1013         use_multiprocessing=use_multiprocessing)
   1014 
   1015   def reset_metrics(self):

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in predict(self, model, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
    496         model, ModeKeys.PREDICT, x=x, batch_size=batch_size, verbose=verbose,
    497         steps=steps, callbacks=callbacks, max_queue_size=max_queue_size,
--> 498         workers=workers, use_multiprocessing=use_multiprocessing, **kwargs)
    499 
    500 

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _model_iteration(self, model, mode, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
    424           max_queue_size=max_queue_size,
    425           workers=workers,
--> 426           use_multiprocessing=use_multiprocessing)
    427       total_samples = _get_total_number_of_samples(adapter)
    428       use_sample = total_samples is not None

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    644     standardize_function = None
    645     x, y, sample_weights = standardize(
--> 646         x, y, sample_weight=sample_weights)
    647   elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter:
    648     standardize_function = standardize

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2381         is_dataset=is_dataset,
   2382         class_weight=class_weight,
-> 2383         batch_size=batch_size)
   2384 
   2385   def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
   2408           feed_input_shapes,
   2409           check_batch_axis=False,  # Don't enforce the batch size.
-> 2410           exception_prefix='input')
   2411 
   2412     # Get typespecs for the input data and sanitize it if necessary.

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    580                              ': expected ' + names[i] + ' to have shape ' +
    581                              str(shape) + ' but got array with shape ' +
--> 582                              str(data_shape))
    583   return data
    584 

ValueError: Error when checking input: expected dense_211_input to have shape (1202,) but got array with shape (1,)
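The last frame of the traceback is the Keras input-shape check. It is called with check_batch_axis=False, so it compares every axis except the first (batch) axis against the expected per-sample shape. The function below is a simplified, hypothetical re-implementation of that comparison, not the Keras source; it shows why an array of shape (1202, 1) is reported as "shape (1,)".

```python
import numpy as np

def check_input(x, expected_shape, name="dense_input"):
    """Simplified sketch of a Keras-style input check that ignores the batch axis."""
    data_shape = tuple(x.shape[1:])  # drop axis 0, which Keras treats as the batch axis
    if data_shape != tuple(expected_shape):
        raise ValueError(
            "Error when checking input: expected " + name + " to have shape "
            + str(tuple(expected_shape)) + " but got array with shape " + str(data_shape)
        )
    return x

state = np.ones((1202, 1))  # as in the question: read as 1202 samples of 1 feature each
try:
    check_input(state, (1202,))
except ValueError as err:
    print(err)

check_input(state.reshape(1, 1202), (1202,))  # one sample with 1202 features passes
```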

【Question comments】:

    Tags: tensorflow keras


    【Solution 1】:

    There are two ways to declare the shape of the model's input:

    Option 1: use input_shape

    model.add(Dense(300, activation='relu', input_shape=(1202,1)))
    

    Here the declared input shape is 2D, but you must feed the network 3D (rank 3) input, because the batch dimension has to be included as the first axis.

    Example input:

    import numpy as np
    BATCH_SIZE = 32  # any batch size works; the batch axis comes first
    state = np.ones((BATCH_SIZE, 1202, 1))
    print("Input rank: {}".format(state.ndim))  # rank of the input (3 here)
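A caveat about this option: a Keras Dense layer operates only on the last axis of its input. With input_shape=(1202, 1) the first layer's kernel has shape (1, 300) and its output has shape (batch, 1202, 300), so the whole model would return (batch, 1202, action_size) rather than one row of action values per sample. Flattening the input (for example with a Flatten layer) or using option 2 avoids this. The NumPy sketch below only reproduces that shape arithmetic; the weights and the batch size are hypothetical.

```python
import numpy as np

def dense_relu(x, kernel, bias):
    """Dense-style layer: multiplies the LAST axis of x by the kernel, adds bias, applies ReLU."""
    return np.maximum(x @ kernel + bias, 0.0)

batch_size = 2                      # hypothetical batch size
x = np.ones((batch_size, 1202, 1))  # rank-3 input, as with input_shape=(1202, 1)
kernel = np.full((1, 300), 0.01)    # kernel shape = (size of last input axis, units) = (1, 300)
bias = np.zeros(300)

y = dense_relu(x, kernel, bias)
print(y.shape)     # the 1202 axis survives the layer

flat = x.reshape(batch_size, -1)    # what a Flatten layer would produce
print(flat.shape)  # now the last axis holds all 1202 features
```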
    

    Option 2: use input_dim

    model.add(Dense(300, activation='relu', input_dim=1202))
    

    Here the declared input shape is 1D, but you must feed the network 2D (rank 2) input, because the batch dimension has to be included as the first axis.

    Example input:

    import numpy as np
    state = np.ones((1, 1202))  # a batch of one sample with 1202 features
    print("Input rank: {}".format(state.ndim))  # rank of the input (2 here)
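Applied to the code in the question, option 2 means reshaping each state to (1, 1202), one sample with 1202 features, instead of np.reshape(new_state, [1202, -1]) as in the traceback, which produces (1202, 1). A minimal sketch, assuming each raw state holds exactly 1202 values; to_batch and raw_state are hypothetical names:

```python
import numpy as np

def to_batch(state, n_features=1202):
    """Reshape a single state into a batch of one sample: shape (1, n_features)."""
    return np.reshape(state, (1, n_features))

raw_state = np.arange(1202, dtype=float)   # hypothetical raw state
wrong = np.reshape(raw_state, [1202, -1])  # as in the question: shape (1202, 1)
right = to_batch(raw_state)                # shape (1, 1202)
print(wrong.shape, right.shape)

# Then, e.g.: target = self.q_network.predict(to_batch(old_state))
```

With a (1, 1202) input, predict returns shape (1, action_size), so the existing target[0][action] indexing in update picks the single sample's row.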
    
