【问题标题】:Why is DQNAgent.fit adding extra dimensions to my input data?为什么 DQNAgent.fit 会向我的输入数据添加额外的维度?
【发布时间】:2021-06-30 20:05:07
【问题描述】:

我正在使用 Keras 的深度 q 学习代理之一:DQNAgent。当我将环境传递给 DQNAgent.fit 时,我收到以下错误:

**3 dqn.fit(env, nb_steps=50000, visualize=False, verbose=1)**

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training_utils_v1.py 在standardize_input_data(数据,名称,形状,check_batch_axis, 异常前缀)

655                            ': expected ' + names[i] + ' to have ' +
656                            str(len(shape)) + ' dimensions, but got array '
**657                            'with shape ' + str(data_shape))**
658         if not check_batch_axis:
659           data_shape = data_shape[1:]

ValueError:检查输入时出错:预期的 dense_18_input 有 2 维,但得到了形状为 (1, 1, 65) 的数组

我的环境的状态和空间定义如下:

self.state = np.zeros(65, dtype=int)
self.action_space = spaces.Tuple((spaces.Discrete(64), spaces.Discrete(64)))
self.observation_space = spaces.Box(low=0, high=16, shape=(65,), dtype=np.int)

我正在使用以下模型:

states = env.observation_space.shape
actions = 64**2
def build_model(states, actions):
    model = Sequential()    
    model.add(Dense(100, activation='relu', input_shape=states))
    model.add(Dense(200, activation='relu'))
    model.add(Dense(actions, activation='linear'))
    return model

我的环境的状态向量的形状为 (65,),但 fit 方法将其增强为 (1, 1, 65)——导致形状不匹配。需要明确的是,self.state 作为来自环境的观察返回。有谁知道为什么会这样?

【问题讨论】:

    标签: python tensorflow dqn keras-rl


    【解决方案1】:

    首先,当您指定模型的输入时,Keras 会添加另一个维度,因为它需要一个 Batch。例如:

    input_shape=(65,) --> (None, 65)
    

    因此,当您将单个观察值转发到模型中时,Keras 会假定 batch_size=1。因此,您的输入大小变为:

    (None, 65) --> (1,65)
    

    现在,为了获得形状为(1,1,65) 的输入,这意味着您喂食并观察大小为batch_size + (1,65) = (1,1,65)。这意味着由于某种原因,您的观察结果在实际输入网络之前被转置(重塑)。

    您在将观察形状输入网络之前检查过它吗?

    【讨论】:

    • 提出此问题的用户提出了以下问题:您是否知道可能导致此问题的原因以及我如何能够解决它?我有几乎相同的问题 - 唯一的区别是我初始化观察空间如下:self.observation_space = Box(low=np.zeros(85), high=np.ones(85), dtype=np.uint8)。任何帮助将不胜感激。
    • 是的,检查代理从环境接收的形状。你可以这样做:observation = env.reset() 然后print(observation.shape, observation.dtype)
    最近更新 更多