【问题标题】:Linear regression with pytorch使用 pytorch 进行线性回归
【发布时间】:2019-01-06 06:34:36
【问题描述】:

我尝试在 ForestFires 数据集上运行线性回归。 数据集在 Kaggle 上可用,我的尝试要点在这里: https://gist.github.com/Chandrak1907/747b1a6045bb64898d5f9140f4cf9a37

我面临两个问题:

  1. 预测输出的形状为 32x1,目标数据的形状为 32。

输入和目标形状不匹配:输入 [32 x 1],目标 [32]¶

使用视图 I 重塑预测张量。

y_pred = y_pred.view(inputs.shape[0])

为什么预测张量和实际张量的形状不匹配?

  1. pytorch 中的 SGD 永远不会收敛。我尝试使用
  2. 手动计算 MSE

打印(torch.mean((y_pred - 标签)**2))

此值不匹配

损失 = 标准(y_pred,标签)

有人可以指出我的代码中的错误在哪里吗?

谢谢。

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    问题 1

    这是 Pytorch 文档中关于 MSELoss 的参考:https://pytorch.org/docs/stable/nn.html#torch.nn.MSELoss

    Shape:
     - Input: (N,∗) where * means, any number of additional dimensions
     - Target: (N,∗), same shape as the input
    

    因此,您需要扩展标签的尺寸:(32) -> (32,1),使用:torch.unsqueeze(labels, 1)labels.view(-1,1)

    https://pytorch.org/docs/stable/torch.html#torch.unsqueeze

    torch.unsqueeze(input, dim, out=None) → 张量

    返回一个新张量,尺寸为一,插入指定位置。

    返回的张量与该张量共享相同的基础数据。

    问题 2

    查看您的代码后,我意识到您已将 size_average 参数添加到 MSELoss:

    criterion = torch.nn.MSELoss(size_average=False)
    

    size_average (bool, optional) – 已弃用(见减少)。默认情况下,损失是批次中每个损失元素的平均值。请注意,对于某些损失,每个样本有多个元素。如果字段 size_average 设置为 False,则将每个 minibatch 的损失相加。当 reduce 为 False 时忽略。默认值:真

    这就是为什么 2 个计算值不匹配的原因。这是示例代码:

    import torch
    import torch.nn as nn
    
    loss1 = nn.MSELoss()
    loss2 = nn.MSELoss(size_average=False)
    inputs = torch.randn(32, 1, requires_grad=True)
    targets = torch.randn(32, 1)
    
    output1 = loss1(inputs, targets)
    output2 = loss2(inputs, targets)
    output3 = torch.mean((inputs - targets) ** 2)
    
    print(output1)  # tensor(1.0907)
    print(output2)  # tensor(34.9021)
    print(output3)  # tensor(1.0907)
    

    【讨论】:

    • 太棒了。这解决了问题 1。但是,使用 torch.mean((y_pred - labels)**2) 计算的损失与 loss = criteria(y_pred,labels) 不匹配。你能告诉我,它们有什么不同吗?
    • @Chandra 我已经意识到您的代码使用了size_average 参数。我已经更新了我的详细答案
    • 是的。谢谢。
    猜你喜欢
    • 2022-01-21
    • 2017-09-19
    • 2021-06-17
    • 2018-07-31
    • 2021-12-14
    • 2021-01-11
    • 2018-12-14
    • 2018-02-03
    • 2018-07-23
    相关资源
    最近更新 更多