【问题标题】:How does the dimensions work when training a keras model?训练 keras 模型时,维度如何工作?
【发布时间】:2019-07-19 01:24:20
【问题描述】:

获取:

    assert q_values.shape == (len(state_batch), self.nb_actions)
AssertionError
q_values.shape <class 'tuple'>: (1, 1, 10)
(len(state_batch), self.nb_actions) <class 'tuple'>: (1, 10)

来自sarsa代理的keras-rl库:

rl.agents.sarsa.SARSAAgent#compute_batch_q_values

    batch = self.process_state_batch(state_batch)
    q_values = self.model.predict_on_batch(batch)
    assert q_values.shape == (len(state_batch), self.nb_actions)

这是我的代码:

class MyEnv(Env):

    def __init__(self):
        self._reset()

    def _reset(self) -> None:
        self.i = 0

    def _get_obs(self) -> List[float]:
        return [1] * 20

    def reset(self) -> List[float]:
        self._reset()
        return self._get_obs()



    model = Sequential()
    model.add(Dense(units=20, activation='relu', input_shape=(1, 20)))
    model.add(Dense(units=10, activation='softmax'))
    logger.info(model.summary())

    policy = BoltzmannQPolicy()
    agent = SARSAAgent(model=model, nb_actions=10, policy=policy)

    optimizer = Adam(lr=1e-3)
    agent.compile(optimizer, metrics=['mae'])

    env = MyEnv()
    agent.fit(env, 1, verbose=2, visualize=True)

想知道是否有人可以向我解释应该如何设置维度以及它如何与库一起使用?我正在输入一个包含 20 个输入的列表,并希望输出 10 个。

【问题讨论】:

    标签: python keras keras-rl


    【解决方案1】:

    此特定错误是由您的输入形状为 (1, 20) 引起的。如果您使用 (20,) 的输入形状,错误就会消失。

    换句话说,SARSAAgent 需要一个输出二维张量(batch_size,nb_actions)的模型。并且您的模型正在输出 (batch_size, 1, 10) 的形状。您可以减少模型输入中的尺寸或展平输出。

    【讨论】:

    • Keras 然后向ValueError: Error when checking input: expected dense_1_input to have 2 dimensions, but got array with shape (1, 1, 20) 投诉。我会看看扁平化的
    • 当您将输入形状更改为 (20,) 时,您需要将您提供给model.fit 的 numpy 数组更改为相应的形状。使用 (1, 20,) 的输入形状比标准 (20,) 增加了更多的复杂性。即带有 1 的额外维度没有提供额外的价值,并首先造成了这个问题。
    【解决方案2】:

    自定义环境

    先搭建一个简单的玩具环境

    1. 这是一个一维迷宫:[1,1,0,1,1,0,1,1,0]
    2. 1:踏入这块迷宫,奖励1
    3. 0:踏入这块迷宫会死,奖励0
    4. 允许的动作0:移动到下一个迷宫块,1:跳过下一个块,即跳过下一个并移动到下一个迷宫块旁边的那个

    要在健身房中实现我们的环境,我们需要实现 2 个方法

    • step:接受一个动作并执行 step 并返回 step take 后的状态,reward 和一个表示游戏是否结束的布尔值
    • reset:重置游戏并返回当前状态(初始状态)

    环境代码

    class FooEnv(gym.Env):
        def __init__(self):
            self.maze = [1,1,0,1,1,0,1,1,0]
            self.curr_state = 0
            self.action_space = spaces.Discrete(2)
            self.observation_space = spaces.Discrete(1)
    
        def step(self, action):        
            if action == 0:
                self.curr_state += 1
            if action == 1:
                self.curr_state += 2
    
            if self.curr_state >= len(self.maze):
                reward = 0.
                done = True
            else:
                if self.maze[self.curr_state] == 0:
                    reward = 0.
                    done = True
                else:
                    reward = 1.
                    done = False
            return np.array(self.curr_state), reward, done, {}
    
        def reset(self):
            self.curr_state = 0
            return np.array(self.curr_state)
    

    神经网络

    现在给定当前状态,我们希望 NN 预测要采取的行动。

    • NN 将采用当前状态,这是一个代表我们所在的当前迷宫块的单个数字作为输入
    • NN 将返回两个可能的操作之一0 或 `1

    NN 代码

    model = Sequential()
    model.add(Dense(units=16, activation='relu', input_shape=(1,)))
    model.add(Dense(units=8, activation='relu'))
    model.add(Dense(units=2, activation='softmax'))
    

    把它放在一起

    policy = BoltzmannQPolicy()
    agent = SARSAAgent(model=model, nb_actions=2, policy=policy)
    
    optimizer = Adam(lr=1e-3)
    agent.compile(optimizer, metrics=['acc'])
    
    env = FooEnv()
    agent.fit(env, 10000, verbose=1, visualize=False)
    # Test the trained agent using
    # agent.test(env, nb_episodes=5, visualize=False)
    

    输出

    Training for 10000 steps ...
    Interval 1 (0 steps performed)
    10000/10000 [==============================] - 54s 5ms/step - reward: 0.6128
    done, took 53.519 seconds
    

    如果您的环境是网格 (2D),如果大小为 n X m,则 NN 的输入大小将为 (n,m),如下所示,并在传递到密集层之前将其展平

    model.add(Flatten(input_shape=(n,m))
    

    keras-rl docs查看这个例子

    【讨论】:

      猜你喜欢
      • 2021-05-14
      • 1970-01-01
      • 2020-04-05
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-09-24
      • 1970-01-01
      相关资源
      最近更新 更多