【问题标题】:CTC: blank must be in label rangeCTC:空白必须在标签范围内
【发布时间】:2019-12-09 05:08:22
【问题描述】:

总结

我正在为验证码识别添加字母,但添加字母后 pytorch 的 CTC 似乎无法正常工作。

我尝试过的

起初,我将 BLANK_LABEL 修改为 62,因为有 62 个标签(0-9,a-z,A-Z),但它给了我运行时错误 blank must be in label range。我还尝试了BLANK_LABEL=0,然后将 1~63 分配为非空白标签,但它输出 NaN 作为损失。

代码

这是我的代码当前版本的 colab 链接:here

以下只是代码的核心部分。

常量:

DATASET_PATH = "/home/ik1ne/Downloads/numbers"
MODEL_PATH = "/home/ik1ne/Downloads"

BATCH_SIZE = 50
TRAIN_BATCHES = 180
TEST_BATCHES = 20
TOTAL_BATCHES = TRAIN_BATCHES+TEST_BATCHES
TOTAL_DATASET = BATCH_SIZE*TOTAL_BATCHES

BLANK_LABEL = 63

数据集生成:

!pip install captcha
from captcha.image import ImageCaptcha

import itertools
import os
import random
import string

if not os.path.exists(DATASET_PATH):
    os.makedirs(DATASET_PATH)

characters = "0123456789"+string.ascii_lowercase + string.ascii_uppercase

while(len(list(Path(DATASET_PATH).glob('*'))) < TOTAL_BATCHES):
    captcha_str = "".join(random.choice(characters) for x in range(6))
    if captcha_str in list(Path(DATASET_PATH).glob('*')):
        continue
    ImageCaptcha().write(captcha_str, f"{DATASET_PATH}/{captcha_str}.png")

数据集:

def convert_strseq_to_numseq(s):
    for c in s:
        if c >= '0' and c <= '9':
            return int(c)
        elif c>='a' and c <='z':
            return ord(c)-ord('a')+10
        else:
            return ord(c)-ord('A')+36

class CaptchaDataset(Dataset):
    """CAPTCHA dataset."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.image_paths = list(Path(root_dir).glob('*'))
        self.transform = transform

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        if self.transform:
            image = self.transform(image)

        label_sequence = [convert_strseq_to_numseq(c) for c in self.image_paths[index].stem]
        return (image, torch.tensor(label_sequence))
    
    def __len__(self):
        return len(self.image_paths)

型号:

class StackedLSTM(nn.Module):
    def __init__(self, input_size=60, output_size=11, hidden_size=512, num_layers=2):
        super(StackedLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(hidden_size, output_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        
    def forward(self, inputs, hidden):
        batch_size, seq_len, input_size = inputs.shape
        outputs, hidden = self.lstm(inputs, hidden)
        outputs = self.dropout(outputs)
        outputs = torch.stack([self.fc(outputs[i]) for i in range(width)])
        outputs = F.log_softmax(outputs, dim=2)
        return outputs, hidden
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data 
        return (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),
                weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
    
net = StackedLSTM().to(device)

训练:

net.train()  # set network to training phase
    
epochs = 30

# for each pass of the training dataset
for epoch in range(epochs):
    train_loss, train_correct, train_total = 0, 0, 0
    
    h = net.init_hidden(BATCH_SIZE)
    
    # for each batch of training examples
    for batch_index, (inputs, targets) in enumerate(train_dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        h = tuple([each.data for each in h])
        
        BATCH_SIZE, channels, height, width = inputs.shape
        
        # reshape inputs: NxCxHxW -> WxNx(HxC)
        inputs = (inputs
                  .permute(3, 0, 2, 1)
                  .contiguous()
                  .view((width, BATCH_SIZE, -1)))
        
        optimizer.zero_grad()  # zero the parameter gradients
        outputs, h = net(inputs, h)  # forward pass

        # compare output with ground truth
        input_lengths = torch.IntTensor(BATCH_SIZE).fill_(width)
        target_lengths = torch.IntTensor([len(t) for t in targets])
        loss = criterion(outputs, targets, input_lengths, target_lengths)

        loss.backward()  # backpropagation
        nn.utils.clip_grad_norm_(net.parameters(), 10)  # clip gradients
        optimizer.step()  # update network weights
        
        # record statistics
        prob, max_index = torch.max(outputs, dim=2)
        train_loss += loss.item()
        train_total += len(targets)

        for i in range(BATCH_SIZE):
            raw_pred = list(max_index[:, i].cpu().numpy())
            pred = [c for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
            target = list(targets[i].cpu().numpy())
            if pred == target:
                train_correct += 1

        # print statistics every 10 batches
        if (batch_index + 1) % 10 == 0:
            print(f'Epoch {epoch + 1}/{epochs}, ' +
                  f'Batch {batch_index + 1}/{len(train_dataloader)}, ' +
                  f'Train Loss: {(train_loss/1):.5f}, ' +
                  f'Train Accuracy: {(train_correct/train_total):.5f}')
            
            train_loss, train_correct, train_total = 0, 0, 0

【问题讨论】:

    标签: python pytorch ctc


    【解决方案1】:

    blank的索引大于类的总数,等于number of chars + blank时会出现这个错误。更重要的是,索引从0开始,而不是1,所以如果你总共有62个字符,它们的索引应该是0-61,空白的索引应该是62而不是63 . (或者你可以设置空白为0,其他字符为1-62

    您还应该检查输出张量的形状,它应该具有形状[T, B, C],其中T 是时间步长,B 是批量大小,C 是类 num,请记住在班级编号中添加空白,否则您将遇到问题

    【讨论】:

      【解决方案2】:

      当它被发送到 CTC loss 时,网络形状很可能存在一些问题,但您应该将数据集提供给我们以查看网络的形状。它应该是 (T,N,C) ,其中 T=输入长度,N=批量大小,C=类数。据我了解,空白符号 id 应该在 0..C 范围内。此外,您应该在字母表中添加空白符号,例如“-”。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2021-01-08
        • 2018-02-22
        • 2023-03-16
        • 2020-06-24
        • 1970-01-01
        • 2020-09-11
        • 2023-01-21
        相关资源
        最近更新 更多