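【Posted】:2021-03-24 20:13:38
【Question】:
Here is the problem:
Before defining the model, we first define the size of the alphabet. Our alphabet consists of the lowercase English letters plus one special character, used for the space between symbols or before and after a word. For the first part of this assignment we don't need that extra character.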
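Our end goal is to learn to transcribe words of arbitrary length. First, however, we pretrain a simple convolutional neural network to recognize single characters. To be able to use the same model for a single character and for a whole word, we design the model so that its output size is 1×27 for one character (i.e. when the input image is 32×18) and K×27 whenever the input image is wider. Here K depends on the particular architecture of the network and is affected by strides, pooling, and so on.
More formally, our model $f_\theta$, for an input image $x$, gives output energies $l = f_\theta(x)$. If $x \in \mathbb{R}^{32 \times 18}$, then $l \in \mathbb{R}^{1 \times 27}$. If, for example, $x \in \mathbb{R}^{32 \times 100}$, our model may output $l \in \mathbb{R}^{10 \times 27}$, where $l_i$ corresponds to a particular window in $x$, for example from $x_{0, 9i}$ to $x_{32, 9i + 18}$ (again, this depends on the particular architecture).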
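To make the dependence of K on kernel sizes, strides and padding concrete, here is a small pure-Python sketch. The `Layer`, `out_size`, `trace` and `window` helpers are hypothetical names, not part of PyTorch. `out_size` applies the output-length formula documented for `nn.Conv2d` and `nn.MaxPool2d` (assuming dilation 1 and `ceil_mode=False`) along one axis, and `window` returns the range of input columns that output column $i$ depends on. Each spatial axis is handled independently, so the same helpers work for height and width.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Layer:
    """One conv or pooling layer along a single spatial axis (hypothetical helper)."""
    kernel: int
    stride: int = 1
    padding: int = 0


def out_size(n: int, layer: Layer) -> int:
    # Output length formula from the PyTorch Conv2d/MaxPool2d docs
    # (dilation=1, ceil_mode=False). 0 means the layer cannot be applied;
    # PyTorch would raise an error in that case.
    return max(0, (n + 2 * layer.padding - layer.kernel) // layer.stride + 1)


def trace(n: int, layers: List[Layer]) -> List[int]:
    """Sizes along one axis before and after every layer."""
    sizes = [n]
    for layer in layers:
        sizes.append(out_size(sizes[-1], layer))
    return sizes


def window(i: int, layers: List[Layer]) -> Tuple[int, int]:
    """Half-open range [start, end) of input columns seen by output column i.

    With padding, the range can extend past the image borders.
    """
    jump, field, offset = 1, 1, 0  # cumulative stride, receptive field, left shift
    for layer in layers:
        field += (layer.kernel - 1) * jump
        offset += layer.padding * jump
        jump *= layer.stride
    start = i * jump - offset
    return start, start + field


# The example in the statement: one window of width 18 every 9 columns.
example = [Layer(kernel=18, stride=9)]
print(trace(18, example)[-1], trace(100, example)[-1])  # 1 10
print(window(1, example))  # (9, 27)

# One axis of SimpleNet: three unpadded 3x3 convs, 2x2 pool, three convs, 2x2 pool.
conv, pool = Layer(3), Layer(2, stride=2)
simple = [conv, conv, conv, pool, conv, conv, conv, pool]
print(trace(32, simple)[-1], trace(340, simple)[-1], trace(18, simple)[-1])  # 3 80 0
```

A single window of width 18 moved 9 columns at a time reproduces the numbers in the statement (K = 1 for width 18, K = 10 for width 100, column $i$ covering columns $9i$ to $9i + 18$). Tracing the layers of `SimpleNet` below gives a final height of 3 rather than 1, a width of 80 for the 32×340 alphabet image, and a width of 0 for a 32×18 character, so that architecture cannot process a single character as written.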
Code:
# constants for number of classes in total, and for the special extra character for empty space
ALPHABET_SIZE = 27  # 26 letters + an extra character for the space in between
BETWEEN = 26
print(alphabet.shape) # RETURNS: torch.Size([32, 340])
My CNN block:
import torch
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
"""
Remember basics:
1. Bigger strides = less overlap
2. More filters = More features
Image shape = 32, 18
Alphabet shape = 32, 340
"""
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block = torch.nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        x = self.cnn_block(x)
        # after applying cnn_block, x.shape should be:
        # batch_size, alphabet_size, 1, width
        return x[:, :, 0, :].permute(0, 2, 1)
model = SimpleNet()
alphabet_energies = model(alphabet.view(1, 1, *alphabet.shape))
def plot_energies(ce):
    fig = plt.figure(dpi=200)
    ax = plt.axes()
    im = ax.imshow(ce.cpu().T)
    ax.set_xlabel('window locations →')
    ax.set_ylabel('← classes')
    ax.xaxis.set_label_position('top')
    ax.set_xticks([])
    ax.set_yticks([])
    cax = fig.add_axes([ax.get_position().x1 + 0.01, ax.get_position().y0, 0.02, ax.get_position().height])
    plt.colorbar(im, cax=cax)
plot_energies(alphabet_energies[0].detach())
I get the error from the title at alphabet_energies = model(alphabet.view(1, 1, *alphabet.shape)).
Any help would be greatly appreciated.
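For comparison, a minimal sketch of one way to meet the 1×27 / K×27 requirement, assuming `alphabet` is a float tensor holding a single-channel (grayscale) image. `CharNet` is a hypothetical name, and the padding, the added `ReLU` layers and the final `(4, 2)` kernel are illustrative choices, not taken from the assignment. The first convolution uses `in_channels=1`, matching the `(1, 1, 32, 340)` tensor produced by `alphabet.view(1, 1, *alphabet.shape)`, whereas `nn.Conv2d(3, 32, 3)` in `SimpleNet` expects three input channels. The last convolution outputs 27 channels (one per class) and collapses the remaining height of 4 to 1, so the `x[:, :, 0, :]` indexing in `forward` is valid.

```python
import torch
from torch import nn

ALPHABET_SIZE = 27  # 26 letters + 1 "between" character


class CharNet(nn.Module):
    """Hypothetical fix: 1 input channel, 27 output channels, height reduced to 1."""

    def __init__(self) -> None:
        super().__init__()
        self.cnn_block = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),  # grayscale input: in_channels=1
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32 x W -> 16 x W//2
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 8 x W//4
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 4 x W//8
            # Final conv spans the remaining height (4) and 2 columns,
            # so a 32x18 input gives a single 1x1 output per class.
            nn.Conv2d(64, ALPHABET_SIZE, kernel_size=(4, 2)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.cnn_block(x)  # (batch, 27, 1, K)
        return x[:, :, 0, :].permute(0, 2, 1)  # (batch, K, 27)


model = CharNet().eval()
with torch.no_grad():
    char = model(torch.randn(1, 1, 32, 18))
    word = model(torch.randn(1, 1, 32, 340))
print(char.shape, word.shape)  # torch.Size([1, 1, 27]) torch.Size([1, 41, 27])
```

With this layout the width shrinks as W → W//8 − 1, so a 32×18 character gives K = 1 and the 32×340 alphabet image gives K = 41. `alphabet_energies[0]` then has shape (41, 27), which `plot_energies` transposes so that classes run along the vertical axis and window locations along the horizontal one.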
【Discussion】:
Tags: pytorch conv-neural-network