1. Repositories

https://github.com/meijieru/crnn.pytorch
The original implementation, written in Lua: https://github.com/bgshih/crnn
The required warp_ctc_pytorch dependency: https://github.com/SeanNaren/warp-ctc

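2. Environment setup

Most ordinary environments should work. I used CUDA 10.0, torch 1.2.0 and Python 3.6; other versions will probably work too.
Install whatever packages turn out to be missing with pip install ***

warp-CTC has to be compiled from source: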

git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding
python setup.py install

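These steps ran without errors for me.
To check that the installation worked, start python and run
import warpctc_pytorch
If there is no error, the installation succeeded.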
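
The same import check can be scripted, which helps when several Python environments are installed side by side. This is a minimal sketch using only the standard library; the helper name can_import is made up for illustration.

```python
import importlib


def can_import(module_name):
    """Return True if the module imports cleanly in this interpreter."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


if __name__ == "__main__":
    # warpctc_pytorch is the module imported in the manual check above.
    print("warpctc_pytorch importable:", can_import("warpctc_pytorch"))
```

Run it with the same interpreter that will run train.py, since a package installed for one Python is not visible to another.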

3. Data preparation: building the lmdb

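The data must be laid out like this: images and label files go in the same folder, each label file has the same name as its image, and its content is the text shown in the image.
Then run the script https://github.com/wuzuowuyou/crnn_pytorch/blob/master/myfile/create_lmdb.py
Note that it has to be run with Python 2. With Python 3 I got all kinds of encoding errors, while under Python 2 it ran without a single error. Python 2 also needs lmdb installed (pip2 install lmdb).
A successful run generates these two files: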
./lmdb/data.mdb
./lmdb/lock.mdb
Put the lmdb folder under the data directory.
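
Before converting, it can help to confirm that every image really has a same-named label file, as the layout at the start of this section requires. The sketch below is not taken from create_lmdb.py; the helper name collect_samples, the .txt label extension and the list of image suffixes are assumptions for illustration.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed image types


def collect_samples(folder):
    """Return sorted (image_path, label) pairs for images with a same-named .txt file."""
    samples = []
    for image_path in sorted(Path(folder).iterdir()):
        if image_path.suffix.lower() not in IMAGE_SUFFIXES:
            continue
        label_path = image_path.with_suffix(".txt")
        if not label_path.is_file():
            # Report the gap instead of failing, so all problems show up in one run.
            print("missing label for", image_path.name)
            continue
        label = label_path.read_text(encoding="utf-8").strip()
        samples.append((str(image_path), label))
    return samples
```

Images without a label are reported and skipped, so the returned list contains only pairs that are safe to write into the lmdb.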

4. Training

python train.py --adadelta --trainRoot ./data/lmdb/ --valRoot ./data/lmdb/ --cuda

One thing to note: if the labels contain both upper-case and lower-case letters, the alphabet has to be changed.
train.py line 32:
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
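
A quick way to see whether the alphabet covers a dataset is to look for label characters that are missing from it. This is a minimal sketch; the helper unknown_chars and the sample labels are made up for illustration, while the alphabet is the same 62-character string as the default above.

```python
import string

# Same 62 characters as the --alphabet default shown above.
alphabet = string.digits + string.ascii_lowercase + string.ascii_uppercase


def unknown_chars(labels, alphabet):
    """Return the sorted list of label characters that are missing from the alphabet."""
    return sorted(set("".join(labels)) - set(alphabet))


labels = ["Hello", "A1b2", "x-ray"]  # made-up labels for illustration
print(unknown_chars(labels, alphabet))  # ['-']
```

Any character reported here would have to be added to the alphabet or removed from the labels before training.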

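5. Fixing errors

I ran into all sorts of errors.
5.1 The capitalization of trainRoot and valRoot needs to be changed so that the names match everywhere they are used.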
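
The note does not say which side was changed, so the sketch below only shows why the capitalization matters: argparse derives the attribute name from the flag exactly as written, and attribute lookup is case-sensitive. The help strings are placeholders.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--trainRoot', required=True, help='path to training lmdb')
parser.add_argument('--valRoot', required=True, help='path to validation lmdb')

# Parse an explicit list so the sketch runs without real command-line arguments.
opt = parser.parse_args(['--trainRoot', './data/lmdb/', '--valRoot', './data/lmdb/'])

print(opt.trainRoot)              # ./data/lmdb/
print(hasattr(opt, 'trainroot'))  # False: attribute names are case-sensitive
```

Code that reads opt.trainroot after declaring --trainRoot would therefore fail with an AttributeError.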
5.2 TypeError: Won't implicitly convert Unicode to bytes; use .encode()
Add .encode() as the error message suggests:
txn.get('num-samples'.encode())
label_byte = txn.get(label_key.encode())
imgbuf = txn.get(img_key.encode())
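
The reason behind this fix is that, under Python 3, lmdb keys are bytes while ordinary string literals are str. The sketch below uses a plain dict as a stand-in for the lmdb transaction so it runs without lmdb installed; the helper lmdb_key and the key format label-%09d are assumptions for illustration.

```python
def lmdb_key(prefix, index):
    """Build a bytes key such as b'label-000000001' (the key format is an assumption)."""
    return ("%s-%09d" % (prefix, index)).encode()


# A plain dict with bytes keys stands in for the lmdb transaction (txn).
fake_txn = {
    "num-samples".encode(): b"10",
    lmdb_key("label", 1): b"Hello",
}

n_samples = int(fake_txn.get("num-samples".encode()))
print(n_samples)                           # 10
print(fake_txn.get("num-samples"))         # None: a str key does not match a bytes key
print(fake_txn.get(lmdb_key("label", 1)))  # b'Hello'
```

The dict quietly returns None for the str key, whereas lmdb raises the TypeError quoted above; in both cases the remedy is to encode the key first.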
5.3
text, _ = self.encode(text)
File "/home/crnn.pytorch/utils.py", line 45, in encode
for char in text
File "/home/crnn.pytorch/utils.py", line 45, in <listcomp>
for char in text
KeyError: 'b'
Fix:
dataset.py line 61, replace
label = str(txn.get(label_key))
with
label_byte = txn.get(label_key.encode())
label = label_byte.decode()
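
Why str() breaks here: in Python 3, calling str() on a bytes object returns its printed form, including the leading b and the quotes, so the label text starts with a stray 'b'. The sketch below reproduces a KeyError: 'b' under the assumption of a digits-only alphabet; the alphabet, the label and the simplified encode function are made up for illustration and are not the code from utils.py.

```python
# Hypothetical digits-only alphabet; index 0 is left free for the CTC blank.
alphabet = "0123456789"
char_to_index = {char: i + 1 for i, char in enumerate(alphabet)}


def encode(text):
    """Map each character to its index, raising KeyError for unknown characters."""
    return [char_to_index[char] for char in text]


raw = b"2021"          # lmdb returns bytes under Python 3

wrong = str(raw)       # "b'2021'": the printed form, not the text
right = raw.decode()   # "2021"

try:
    encode(wrong)
    failed_on = None
except KeyError as err:
    failed_on = err.args[0]

print(repr(wrong), "fails on", repr(failed_on))  # "b'2021'" fails on 'b'
print(encode(right))                             # [3, 1, 3, 2]
```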

5.4 raise ValueError('sampler option is mutually exclusive with '
ValueError: sampler option is mutually exclusive with shuffle
In short, sampler and shuffle are mutually exclusive.
I appended and 0 to the condition so that the sampler is never used:
if not opt.random_sample and 0:
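
An alternative to switching the sampler off is to pass shuffle only when no sampler is given. This is a minimal sketch rather than the change made in train.py; the helper loader_kwargs is made up for illustration, and only the DataLoader parameter names batch_size, shuffle and sampler come from PyTorch.

```python
def loader_kwargs(batch_size, sampler=None):
    """Keyword arguments for torch.utils.data.DataLoader, never setting both shuffle and sampler."""
    kwargs = {"batch_size": batch_size}
    if sampler is not None:
        kwargs["sampler"] = sampler  # the sampler decides the order
    else:
        kwargs["shuffle"] = True     # no sampler: let the loader shuffle
    return kwargs


print(loader_kwargs(64))  # {'batch_size': 64, 'shuffle': True}
print(loader_kwargs(64, sampler="placeholder-sampler"))
```

The result would be used as torch.utils.data.DataLoader(train_dataset, **loader_kwargs(batch_size, sampler)), which keeps the custom sampler available instead of disabling it.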

5.5 Validation also raised an error:
Start val
Traceback (most recent call last):
File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 219, in <module>
val(crnn, test_dataset, criterion)
File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 168, in val
preds = preds.squeeze(2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
I skipped validation by appending and 0 to the condition:
if i % opt.valInterval == 0 and 0:
    val(crnn, test_dataset, criterion)

With the errors fixed, training can start. The output looks like this:

  (relu6): ReLU(inplace=True)
  )
  (rnn): Sequential(
    (0): BidirectionalLSTM(
      (rnn): LSTM(512, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=256, bias=True)
    )
    (1): BidirectionalLSTM(
      (rnn): LSTM(256, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=63, bias=True)
    )
  )
)
[0/100000000][1/9] Loss: 8.430408
[0/100000000][2/9] Loss: 20.137066
[0/100000000][3/9] Loss: 25.239346
[0/100000000][4/9] Loss: 21.249365
[0/100000000][5/9] Loss: 20.604660
[0/100000000][6/9] Loss: 14.782236

6. Testing with demo.py

This line needs to be changed so that it matches the settings used for training:
model = crnn.CRNN(32, 1, 37, 256)
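
The third argument is the number of output classes: one per alphabet character plus one for the CTC blank. A minimal sketch of the arithmetic, with a made-up helper name num_classes:

```python
import string


def num_classes(alphabet):
    """Number of network outputs: one per alphabet character plus one CTC blank."""
    return len(alphabet) + 1


lowercase_only = string.digits + string.ascii_lowercase
mixed_case = lowercase_only + string.ascii_uppercase

print(num_classes(lowercase_only))  # 37
print(num_classes(mixed_case))      # 63
```

The value 63 agrees with out_features=63 in the training printout above, so a model trained with the 62-character alphabet cannot be loaded into a network built with 37 classes.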

Error:
File "/data_2/project_2021/crnn/crnn.pytorch-master/demo_show.py", line 42, in <module>
model.load_state_dict(torch.load(model_path))
File "/data_1/Yang/software_install/Anaconda1105/envs/CenterNet_1.0_3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CRNN:
Missing key(s) in state_dict: "cnn.conv0.weight", "cnn.conv0.bias", "cnn.conv1.weight", "cnn.conv1.bias", "cnn.conv2.weight", "cnn.conv2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.bias", "cnn.batchnorm2.running_mean", "cnn.batchnorm2.running_var", "cnn.conv3.weight", "cnn.conv3.bias", "cnn.conv4.weight", "cnn.conv4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.bias", "cnn.batchnorm4.running_mean", "cnn.batchnorm4.running_var", "cnn.conv5.weight", "cnn.conv5.bias", "cnn.conv6.weight", "cnn.conv6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.bias", "cnn.batchnorm6.running_mean", "cnn.batchnorm6.running_var", "rnn.0.rnn.weight_ih_l0", "rnn.0.rnn.weight_hh_l0", "rnn.0.rnn.bias_ih_l0", "rnn.0.rnn.bias_hh_l0", "rnn.0.rnn.weight_ih_l0_reverse", "rnn.0.rnn.weight_hh_l0_reverse", "rnn.0.rnn.bias_ih_l0_reverse", "rnn.0.rnn.bias_hh_l0_reverse", "rnn.0.embedding.weight", "rnn.0.embedding.bias", "rnn.1.rnn.weight_ih_l0", "rnn.1.rnn.weight_hh_l0", "rnn.1.rnn.bias_ih_l0", "rnn.1.rnn.bias_hh_l0", "rnn.1.rnn.weight_ih_l0_reverse", "rnn.1.rnn.weight_hh_l0_reverse", "rnn.1.rnn.bias_ih_l0_reverse", "rnn.1.rnn.bias_hh_l0_reverse", "rnn.1.embedding.weight", "rnn.1.embedding.bias".
Unexpected key(s) in state_dict: "module.cnn.conv0.weight", "module.cnn.conv0.bias", "module.cnn.conv1.weight", "module.cnn.conv1.bias", "module.cnn.conv2.weight", "module.cnn.conv2.bias", "module.cnn.batchnorm2.weight", "module.cnn.batchnorm2.bias", "module.cnn.batchnorm2.running_mean", "module.cnn.batchnorm2.running_var", "module.cnn.batchnorm2.num_batches_tracked", "module.cnn.conv3.weight", "module.cnn.conv3.bias", "module.cnn.conv4.weight", "module.cnn.conv4.bias", "module.cnn.batchnorm4.weight", "module.cnn.batchnorm4.bias", "module.cnn.batchnorm4.running_mean", "module.cnn.batchnorm4.running_var", "module.cnn.batchnorm4.num_batches_tracked", "module.cnn.conv5.weight", "module.cnn.conv5.bias", "module.cnn.conv6.weight", "module.cnn.conv6.bias", "module.cnn.batchnorm6.weight", "module.cnn.batchnorm6.bias", "module.cnn.batchnorm6.running_mean", "module.cnn.batchnorm6.running_var", "module.cnn.batchnorm6.num_batches_tracked", "module.rnn.0.rnn.weight_ih_l0", "module.rnn.0.rnn.weight_hh_l0", "module.rnn.0.rnn.bias_ih_l0", "module.rnn.0.rnn.bias_hh_l0", "module.rnn.0.rnn.weight_ih_l0_reverse", "module.rnn.0.rnn.weight_hh_l0_reverse", "module.rnn.0.rnn.bias_ih_l0_reverse", "module.rnn.0.rnn.bias_hh_l0_reverse", "module.rnn.0.embedding.weight", "module.rnn.0.embedding.bias", "module.rnn.1.rnn.weight_ih_l0", "module.rnn.1.rnn.weight_hh_l0", "module.rnn.1.rnn.bias_ih_l0", "module.rnn.1.rnn.bias_hh_l0", "module.rnn.1.rnn.weight_ih_l0_reverse", "module.rnn.1.rnn.weight_hh_l0_reverse", "module.rnn.1.rnn.bias_ih_l0_reverse", "module.rnn.1.rnn.bias_hh_l0_reverse", "module.rnn.1.embedding.weight", "module.rnn.1.embedding.bias".

Process finished with exit code 1

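The cause is that the parameter names in the saved .pth weights carry an extra module. prefix; removing it fixes the error. This prefix typically appears when the model was wrapped in torch.nn.DataParallel during training.
The loading code needs to be changed as follows: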

import collections

# nclass must match training: alphabet size plus one for the CTC blank.
nclass = len(alphabet) + 1

model = crnn.CRNN(32, 1, nclass, 256)  # was: crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()

load_model_ = torch.load(model_path)

# Uncomment to compare the key names in the model and in the checkpoint:
# for m in model.state_dict().keys():
#     print("==:: ", m)
# for k, v in load_model_.items():
#     print(k, "  ::shape", v.shape)

# Strip the leading `module.` (7 characters) from the keys that have it.
state_dict_rename = collections.OrderedDict()
for k, v in load_model_.items():
    name = k[7:] if k.startswith('module.') else k
    state_dict_rename[name] = v

print('loading pretrained model from %s' % model_path)
model.load_state_dict(state_dict_rename)
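
If the renaming is needed in more than one script, it can be wrapped in a small helper. This is a sketch with a made-up function name, strip_module_prefix; the placeholder integers stand in for tensors because only the key names matter here.

```python
import collections


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with the prefix removed from the keys that have it."""
    renamed = collections.OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        renamed[key] = value
    return renamed


# Placeholder values stand in for tensors; only the key names matter here.
checkpoint = collections.OrderedDict([
    ("module.cnn.conv0.weight", 0),
    ("module.rnn.1.embedding.bias", 1),
])
print(list(strip_module_prefix(checkpoint)))
# ['cnn.conv0.weight', 'rnn.1.embedding.bias']
```

Keys without the prefix pass through unchanged, so the helper is also safe to call on a checkpoint that was saved from an unwrapped model.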

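After that, testing works.
There were too many changes to list, so I uploaded the modified code to GitHub for anyone who needs it. It includes 10 test images with labels, which is enough to run the lmdb conversion.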
https://github.com/wuzuowuyou/crnn_pytorch
