【问题标题】:RuntimeError: expected scalar type Float but found Double (LSTM classifier)RuntimeError:预期的标量类型 Float 但发现 Double(LSTM 分类器)
【发布时间】:2021-09-15 22:13:33
【问题描述】:

我正在训练我的 LSTM 分类器。

epoch_num = 30

train_log = []
test_log = []
set_seed(111)
for epoch in range(1, epoch_num+1):

running_loss = 0    
train_loss = []
lstm_classifier.train()
for (inputs, labels) in tqdm(train_loader, desc='Training epoch ' + str(epoch), leave=False):        
    inputs, labels = inputs.to(device), labels.to(device)        
    optimizer.zero_grad()
    outputs = lstm_classifier(inputs)   
    loss = criterion(outputs, labels)
    loss.backward()                
    optimizer.step()        
    train_loss.append(loss.item())
train_log.append(np.mean(train_loss))

running_loss = 0
test_loss = []
lstm_classifier.eval()
with torch.no_grad():                
    for (inputs, labels) in tqdm(test_loader, desc='Test', leave=False):         
        inputs, labels = inputs.to(device), labels.to(device)        
        outputs = lstm_classifier(inputs)                       
        loss = criterion(outputs, labels)            
        test_loss.append(loss.item())
test_log.append(np.mean(test_loss))    
plt.plot(range(1, epoch+1), train_log, color='C0')
plt.plot(range(1, epoch+1), test_log, color='C1')
display.clear_output(wait=True)
display.display(plt.gcf())

错误是:

RuntimeError Traceback(最近一次调用最后一次) 在 ()

     23         print((labels.dtype))
     24         print(outputs[:,0].dtype)
---> 25         loss = criterion(outputs, labels)
     26         loss.backward()
     27         optimizer.step()

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825 
   2826 

RuntimeError: 预期的标量类型 Float 但发现 Double

如何解决?

【问题讨论】:

    标签: python tensorflow pytorch lstm


    【解决方案1】:

    RuntimeError: 预期的标量类型 Float 但发现 Double

    loss = criterion(outputs, labels) 行的错误非常明显,因为它要求您的数据类型是 float 而不是 double,但它没有明确说明是 outputs 还是 label 正在创建它。

    我的猜测是因为标签。尝试通过 labels.float() 将其转换为浮点数

    【讨论】:

    • 奇数。你能提供输出和标签的样本来复制这个错误吗?
    猜你喜欢
    • 2021-07-10
    • 1970-01-01
    • 2020-06-11
    • 2023-03-21
    • 2020-05-31
    • 2021-03-05
    • 1970-01-01
    • 2020-06-15
    • 2019-11-06
    相关资源
    最近更新 更多