【问题标题】:RuntimeError: size mismatch m1: [a x b], m2: [c x d]RuntimeError: 尺寸不匹配 m1: [a x b], m2: [c x d]
【发布时间】:2019-05-18 14:30:50
【问题描述】:

任何人都可以帮助我吗?我得到以下错误。我使用谷歌 Colab。如何解决这个错误?

大小不匹配,m1:[64 x 100],m2:[784 x 128] 在 /pytorch/aten/src/TH/generic/THTensorMath.cpp:2070

下面的代码我正在尝试运行。

    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision import datasets, transforms

    # Define a transform to normalize the data
    transform = 
    transforms.Compose([transforms.CenterCrop(10),transforms.ToTensor(),])
    # Download the load the training data
    trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, 
    train=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, 
    shuffle=True)

    # Build a feed-forward network
    model = nn.Sequential(nn.Linear(784, 128),nn.ReLU(),nn.Linear(128, 
    64),nn.ReLU(),nn.Linear(64, 10))

    # Define the loss
    criterion = nn.CrossEntropyLoss()

   # Get our data
   images, labels = next(iter(trainloader))
   # Faltten images
   images = images.view(images.shape[0], -1)

   # Forward pass, get our logits
   logits = model(images)
   # Calculate the loss with the logits and the labels
   loss = criterion(logits, labels)
   print(loss)

【问题讨论】:

    标签: python-3.x machine-learning image-processing computer-vision pytorch


    【解决方案1】:

    您只需要关心b=c,您就完成了:

    m1: [a x b], m2: [c x d]
    

    m1[a x b][batch size x in features]

    m2[c x d][in features x out features]

    【讨论】:

    • 我该如何解决这个错误 RuntimeError: size mismatch, m1: [4 x 2048], m2: [1568 x 10]?
    【解决方案2】:

    您的尺寸不匹配!
    您的第一层 model 需要 784 维输入(我假设您通过 28x28=784 获得此值,即 mnist 数字的大小)。
    但是,您的 trainset 应用 transforms.CenterCrop(10) - 也就是说,它从图像中心裁剪了一个 10x10 的区域,因此您的输入尺寸实际上是 100。

    总结:
    - 你的第一层:nn.Linear(784, 128) 需要一个 784 维的输入,输出一个 128 维的隐藏特征向量(每个输入)。因此,该层的权重矩阵为[784 x 128](错误消息中的“m2”)。
    - 您的输入被中心裁剪为 10x10 像素(总共 100-dim),并且您在每批中都有 batch_size=64 这样的图像,总共有 [64 x 100] 输入大小(错误消息中的“m1”)。
    - 您无法计算大小不匹配的矩阵之间的点积:100 != 784,因此 pytorch 会给您一个错误。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-07-25
      • 1970-01-01
      • 2021-05-13
      • 2012-01-28
      • 1970-01-01
      • 2020-06-11
      相关资源
      最近更新 更多