Question title: Can't print model summary using PyTorch?
Posted: 2022-11-20 23:01:02
Question description:

Hi, I'm building a DQN model for reinforcement learning on CartPole and would like to print a model summary, like the Keras model.summary() function does.

Here is my model class:

import torch

class DQN(torch.nn.Module):
    ''' Deep Q Neural Network class. '''
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=0.05):
        super(DQN, self).__init__()
        self.criterion = torch.nn.MSELoss()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim*2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim*2, action_dim)
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def update(self, state, y):
        """Update the weights of the network given a training sample. """
        y_pred = self.model(torch.Tensor(state))
        # Variable is no longer needed: plain tensors track gradients directly
        loss = self.criterion(y_pred, torch.Tensor(y))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def predict(self, state):
        """ Compute Q values for all actions using the DQL. """
        with torch.no_grad():
            return self.model(torch.Tensor(state))

And here is the model instance, created with these arguments:

import gym

# Assumption: the exact CartPole version was not stated in the question
env = gym.make('CartPole-v1')

# Number of states = 4
n_state = env.observation_space.shape[0]
# Number of actions = 2
n_action = env.action_space.n
# Number of episodes
episodes = 150
# Number of hidden nodes in the DQN
n_hidden = 50
# Learning rate
lr = 0.001

simple_dqn = DQN(n_state, n_action, n_hidden, lr)
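Even before forward exists, PyTorch's built-in printout (print(simple_dqn) or print(simple_dqn.model)) already lists the layers, since printing a module never runs it. For Keras-style output shapes and parameter counts, a small forward-hook helper can be run on the Sequential stack directly, because nn.Sequential already has its own forward. This is a minimal sketch: print_summary is a hypothetical helper (not a PyTorch or torchinfo API), and the layer sizes assume n_state=4, n_action=2 and n_hidden=50 as in the question.

```python
import torch
import torch.nn as nn


def print_summary(model: nn.Module, input_size: tuple) -> int:
    """Hypothetical helper: print one row per direct child layer
    (type, output shape, parameter count) and return the total parameter count."""
    rows = []

    def record(module, inputs, output):
        # PyTorch calls this right after each hooked layer's forward pass.
        n_params = sum(p.numel() for p in module.parameters())
        rows.append((type(module).__name__, tuple(output.shape), n_params))

    # Attach the hook to every direct child (each layer of a Sequential).
    handles = [child.register_forward_hook(record) for child in model.children()]
    try:
        with torch.no_grad():
            model(torch.zeros(*input_size))  # dummy forward pass to record shapes
    finally:
        for handle in handles:
            handle.remove()  # detach the hooks again

    print(f"{'Layer':<10}{'Output shape':<16}{'Params':>8}")
    for layer_type, shape, n_params in rows:
        print(f"{layer_type:<10}{str(shape):<16}{n_params:>8}")
    total = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total}")
    return total


# Same layer stack as DQN.model, assuming n_state=4, n_action=2, n_hidden=50.
net = nn.Sequential(
    nn.Linear(4, 50), nn.ReLU(),
    nn.Linear(50, 100), nn.ReLU(),
    nn.Linear(100, 2),
)
print(net)  # built-in layer listing; printing never calls forward()
total = print_summary(net, input_size=(1, 4))  # batch of 1, 4 state features
```

With the question's objects, the equivalent call would be print_summary(simple_dqn.model, (1, n_state)); the total here is 250 + 5100 + 202 = 5552 parameters.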

I tried using torchinfo's summary:

from torchinfo import summary
simple_dqn = DQN(n_state, n_action, n_hidden, lr)
summary(simple_dqn, input_size=(4, 2, 50))

But I get the following error:

NotImplementedError                       Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    286             if isinstance(x, (list, tuple)):
--> 287                 _ = model.to(device)(*x, **kwargs)
    288             elif isinstance(x, dict):

4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1147 
-> 1148         result = forward_call(*input, **kwargs)
   1149         if _global_forward_hooks or self._forward_hooks:

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    200     """
--> 201     raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
    202 

NotImplementedError: Module [DQN] is missing the required "forward" function

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-24-ee921f7e5cb5> in <module>
      1 from torchinfo import summary
      2 simple_dqn = DQN(n_state, n_action, n_hidden, lr)
----> 3 summary(simple_dqn, input_size=(4, 2, 50))

/usr/local/lib/python3.7/dist-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, mode, row_settings, verbose, **kwargs)
    216     )
    217     summary_list = forward_pass(
--> 218         model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
    219     )
    220     formatting = FormattingOptions(depth, verbose, columns, col_width, rows)

/usr/local/lib/python3.7/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    297             "Failed to run torchinfo. See above stack traces for more details. "
    298             f"Executed layers up to: {executed_layers}"
--> 299         ) from e
    300     finally:
    301         if hooks:

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

Any help is appreciated.

Question discussion:

    Tags: python pytorch


    Answer 1:

    If you look at the stack trace, you can see that this is the first error it throws:

    /usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
        200     """
    --> 201     raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
        202 
    
    NotImplementedError: Module [DQN] is missing the required "forward" function
    

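    This is the main error to look at: your model is missing the forward method that every torch.nn.Module subclass must define, and that torchinfo calls when it runs a sample input through the model. Here is an example of how you would implement it: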

    # Example from a small CNN: assumes __init__ defines self.conv1, self.conv2,
    # self.pool, self.fc1, self.fc2 and self.fc3
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
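Why the missing method matters: calling any nn.Module instance runs its forward method, and if a subclass never defines one, nn.Module falls back to a stub that raises NotImplementedError, which is exactly what torchinfo hits. A minimal sketch with two hypothetical toy classes (NoForward and WithForward, not part of the question) reproduces the error and the fix:

```python
import torch
import torch.nn as nn


class NoForward(nn.Module):
    """Registers a layer but never defines forward()."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


class WithForward(NoForward):
    """Same layer, plus the forward() that calling the module dispatches to."""
    def forward(self, x):
        return self.fc(x)


x = torch.zeros(1, 4)
try:
    NoForward()(x)  # falls through to nn.Module's unimplemented-forward stub
    raised = False
except NotImplementedError as err:
    raised = True
    print(err)  # the 'missing the required "forward" function' message

out = WithForward()(x)  # the call now succeeds
print(out.shape)  # torch.Size([1, 2])
```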
    

    Comments:

    • Thanks for your answer. Using your forward function code gives an error, AttributeError: 'DQN' object has no attribute 'pool', referring to self.pool.
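The AttributeError arises because the example forward above uses layers (self.pool, self.conv1 and so on) from a different, convolutional model. For this DQN, forward only needs to pass the input through the Sequential stack that __init__ already builds. A sketch of the question's class with that forward added, assuming it subclasses torch.nn.Module (which the original traceback implies), and with update and predict routed through self(...):

```python
import torch
import torch.nn as nn


class DQN(nn.Module):
    """Deep Q-network whose forward() uses only the layers it defines."""
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=0.05):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.model = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, action_dim),
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def forward(self, x):
        # Delegate to the Sequential stack; no pool/conv/fc attributes exist here.
        return self.model(x)

    def update(self, state, y):
        """One gradient step on a training sample."""
        y_pred = self(torch.as_tensor(state, dtype=torch.float32))
        loss = self.criterion(y_pred, torch.as_tensor(y, dtype=torch.float32))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def predict(self, state):
        """Q values for all actions, without tracking gradients."""
        with torch.no_grad():
            return self(torch.as_tensor(state, dtype=torch.float32))


simple_dqn = DQN(4, 2, 50, 0.001)  # CartPole: 4 state features, 2 actions
q = simple_dqn.predict([[0.0, 0.0, 0.0, 0.0]])
print(q.shape)  # torch.Size([1, 2])
```

With forward defined, a call along the lines of summary(simple_dqn, input_size=(1, 4)) should then let torchinfo trace the model. torchinfo's input_size includes the batch dimension, and the last dimension must equal state_dim, so the question's input_size=(4, 2, 50) would still fail once forward exists: the first Linear layer expects 4 input features, not 50.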