【Posted】: 2020-05-24 18:23:06
【Question】:
I am testing my CNN model, but I keep getting the error "AttributeError: 'numpy.ndarray' object has no attribute 'relu'".
My dataset is built with the following code:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

class MyDataset(Dataset):
    def __init__(self, data, target, transform=None):
        self.data = torch.from_numpy(data).float()
        self.target = torch.from_numpy(target).long()
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.data)

numpy_data = np.random.randn(100, 3, 224, 224)  # 100 samples, image size = 224 x 224 x 3
numpy_target = np.random.randint(0, 5, size=(100))
dataset = MyDataset(numpy_data, numpy_target)
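The `.float()` and `.long()` calls in `MyDataset.__init__` matter because of the dtypes NumPy produces: `randn` returns float64, while `torch.float` is float32 and `torch.long` is int64. A minimal NumPy-only sketch (it needs no PyTorch) that inspects the synthetic arrays and performs the equivalent casts with `astype`; `data32` and `target64` are illustrative names, not part of the original code:

```python
import numpy as np

# Same synthetic data as in the question: 100 samples of shape 3 x 224 x 224
numpy_data = np.random.randn(100, 3, 224, 224)
numpy_target = np.random.randint(0, 5, size=(100))

# randn returns float64; MyDataset's .float() casts this to float32
print(numpy_data.shape, numpy_data.dtype)  # (100, 3, 224, 224) float64

# randint uses NumPy's default integer type, which is platform dependent
# (int32 on Windows with NumPy < 2), hence the explicit .long() (int64) cast
print(numpy_target.shape, numpy_target.dtype)

# NumPy-only equivalents of the casts done in MyDataset.__init__
data32 = numpy_data.astype(np.float32)
target64 = numpy_target.astype(np.int64)
```

The targets are class labels in the range 0 to 4, which is the integer format classification losses such as `nn.CrossEntropyLoss` take.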
My model is very simple:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=36, kernel_size=100)

    def forward(self, t):
        t = self.conv1(t)
        print(t.shape)
        print(type(t))
        t = F.relu(t)
        print(t.shape)
        return t
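The 125 x 125 spatial size in the printed output follows from the output-size formula given in the `nn.Conv2d` documentation, floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. A small sketch that evaluates it for one spatial dimension; `conv2d_output_size` is a hypothetical helper, not a PyTorch function:

```python
def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size of nn.Conv2d along one spatial dimension.

    Evaluates floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1,
    using integer floor division for the floor.
    """
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# 224 x 224 input, kernel_size=100, default stride/padding/dilation
print(conv2d_output_size(224, kernel_size=100))  # 125
```

The channel dimension of the output equals `out_channels`, so the `Network` as posted (with `out_channels = 36`) would print `torch.Size([1, 36, 125, 125])`; the `6` in the output below suggests the notebook cell that actually ran held a slightly different version of the class.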
I test the model with the following:
sample, target = next(iter(dataset))
network=Network()
pred = network(sample.unsqueeze(0))
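`next(iter(dataset))` works even though `MyDataset` defines no `__iter__` of its own: for an object that only provides `__getitem__`, Python's `iter()` falls back to the legacy sequence protocol, calling `__getitem__(0)`, `__getitem__(1)`, and so on until an `IndexError` is raised. The returned sample has shape `(3, 224, 224)`, and `unsqueeze(0)` prepends a batch dimension, giving `(1, 3, 224, 224)`. A plain-Python sketch of the iteration fallback, using a hypothetical `ToyDataset` class in place of `MyDataset`:

```python
class ToyDataset:
    """Hypothetical stand-in for MyDataset: only __getitem__ and __len__."""

    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __getitem__(self, index):
        # List indexing raises IndexError past the end, which ends iteration
        return self.data[index], self.target[index]

    def __len__(self):
        return len(self.data)


ds = ToyDataset([[0.1, 0.2], [0.3, 0.4]], [3, 1])
sample, target = next(iter(ds))  # iter() falls back to ds.__getitem__(0)
print(sample, target)            # [0.1, 0.2] 3
print(list(ds))                  # [([0.1, 0.2], 3), ([0.3, 0.4], 1)]
```

For batched training, wrapping the dataset in the already-imported `DataLoader` is the usual route; the fallback shown here only explains why the single-sample test runs.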
I get the following output and error:
torch.Size([1, 6, 125, 125])
<class 'torch.Tensor'>
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-40-37e58cfe971f> in <module>
----> 1 pred = network(sample.unsqueeze(0))
C:\Miniconda\envs\py37_default\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
<ipython-input-30-0d0592ad705d> in forward(self, t)
22 print(t.shape)
23 print(type(t))
---> 24 t = F.relu(t)
25 print(t.shape)
26 #t = F.max_pool2d(t, kernel_size=2, stride=2)
AttributeError: 'numpy.ndarray' object has no attribute 'relu'
I don't understand why type(t) prints <class 'torch.Tensor'>, yet the error says it is a numpy.ndarray.
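【Comments】:
- Where do you define F? F is the NumPy array, not t.
- Yes, I did define F; I forgot to include it here and have just added it. But I still get the same error.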
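The first comment's diagnosis is that the global name `F` no longer refers to `torch.nn.functional`. If some later notebook cell rebinds `F` to a NumPy array, `forward` picks up the array, because it looks `F` up at call time; `t` is still a tensor, which is why `type(t)` prints `torch.Tensor` while the error names `numpy.ndarray`. Adding the import to the question does not help if the rebinding cell has already run in the session. A minimal reproduction, assuming that kind of reassignment and using the standard-library `math` module as a stand-in for `torch.nn.functional` so it runs without PyTorch:

```python
import types

import numpy as np

import math as F  # stand-in for `import torch.nn.functional as F`

# Later in the same session, the global name F is rebound to an array,
# e.g. an intermediate result stored in a variable that is also called F
F = np.zeros(3)

message = None
try:
    F.relu(np.ones(3))  # same lookup that F.relu(t) performs inside forward
except AttributeError as exc:
    message = str(exc)
print(message)  # 'numpy.ndarray' object has no attribute 'relu'

# Quick diagnostic: F should be a module, not an array
print(isinstance(F, types.ModuleType))  # False

# Re-running the import rebinds F to the module again
import math as F
print(isinstance(F, types.ModuleType))  # True
```

In the actual notebook, printing `type(F)` inside `forward` shows the same thing. Re-running `import torch.nn.functional as F` (and renaming the array variable) restores the module; calling `torch.relu(t)`, or registering an `nn.ReLU()` module in `__init__`, avoids depending on the global name `F` at all.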
Tags: python tensorflow pytorch