【Question Title】: Result of auto-encoder dimensions is incorrect
【Posted】: 2019-07-01 06:16:24
【Question】:

Using the code below, I am trying to encode the images in MNIST into a lower-dimensional representation:

import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn import metrics
import datetime
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import scale
from ast import literal_eval
import seaborn as sns
sns.set_style("darkgrid")
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable

%matplotlib inline

low_dim_rep = 32
epochs = 2

cuda = torch.cuda.is_available() # True if cuda is available, False otherwise
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
print('Training on %s' % ('GPU' if cuda else 'CPU'))

# Loading the MNIST data set
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))])
mnist = torchvision.datasets.MNIST(root='../data/', train=True, transform=transform, download=True)

# Loader to feed the data batch by batch during training.
batch = 100
data_loader = torch.utils.data.DataLoader(mnist, batch_size=batch, shuffle=True)


encoder = nn.Sequential(
                # Encoder
                nn.Linear(28 * 28, 64),
                nn.PReLU(64),
                nn.BatchNorm1d(64),

                # Low-dimensional representation
                nn.Linear(64, low_dim_rep),
                nn.PReLU(low_dim_rep),
                nn.BatchNorm1d(low_dim_rep))

decoder = nn.Sequential(
                # Decoder
                nn.Linear(low_dim_rep, 64),
                nn.PReLU(64),
                nn.BatchNorm1d(64),
                nn.Linear(64, 28 * 28))

autoencoder = nn.Sequential(encoder, decoder)

encoder = encoder.type(FloatTensor)
decoder = decoder.type(FloatTensor)
autoencoder = autoencoder.type(FloatTensor)

optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=0.00001)


data_size = int(mnist.train_labels.size()[0])

print('data_size' , data_size)
for i in range(epochs):
    for j, (images, _) in enumerate(data_loader):
        images = images.view(images.size(0), -1) # from (batch, 1, 28, 28) to (batch, 784)
        images = Variable(images).type(FloatTensor)

        autoencoder.zero_grad()
        reconstructions = autoencoder(images)
        loss = torch.dist(images, reconstructions)
        loss.backward()
        optimizer.step()
    print('Epoch %i/%i loss %.2f' % (i + 1, epochs, loss.item()))  # loss.item() replaces the deprecated loss.data[0]

print('Optimization finished.')

# Get the encoded images here
encoded_images = []
for j, (images, _) in enumerate(data_loader):
    images = images.view(images.size(0), -1) 
    images = Variable(images).type(FloatTensor)

    encoded_images.append(encoder(images))

After this code completes,

len(encoded_images) is 600, whereas I expect the length to match the number of images in mnist: len(mnist), i.e. 60,000.

How do I encode the images into a low-dimensional representation of size 32 (low_dim_rep = 32)? Are the network parameters defined incorrectly?

【Comments】:

    Tags: deep-learning pytorch autoencoder


    【Answer 1】:

    You have 60000 images in mnist and your batch = 100. That is why len(encoded_images) = 600: you perform 60000/100 = 600 iterations when generating the encoded images, so you end up with a list of 600 elements, where each element has shape [100, 32]. You can do the following instead:

    encoded_images = torch.zeros(len(mnist), 32)
    for j, (images, _) in enumerate(data_loader):
        images = images.view(images.size(0), -1) 
        images = Variable(images).type(FloatTensor)
        encoded_images[j * batch : (j+1) * batch] = encoder(images)
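
    One caveat worth adding: this collection pass still builds an autograd graph and uses per-batch statistics in the BatchNorm layers. Below is a minimal sketch of a gradient-free variant, reusing the encoder, data_loader, mnist, batch, low_dim_rep and FloatTensor names defined in the question (the .cpu() call is an assumption for the case where the model lives on the GPU; it is a no-op on CPU):

    encoder.eval()  # BatchNorm layers use their running statistics instead of batch statistics
    encoded_images = torch.zeros(len(mnist), low_dim_rep)
    with torch.no_grad():  # no autograd graph is built while collecting embeddings
        for j, (images, _) in enumerate(data_loader):
            images = images.view(images.size(0), -1).type(FloatTensor)
            encoded_images[j * batch : (j + 1) * batch] = encoder(images).cpu()

    Note also that data_loader was created with shuffle=True, so the rows of encoded_images will not follow the original dataset order; a separate DataLoader with shuffle=False would preserve it.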
    

    【Discussion】:

    • Thanks, but using the code above returns an error at the line encoded_images[j * batch : (j+1) * batch] = encoder(images): RuntimeError: The expanded size of the tensor (32) must match the existing size (4) at non-singleton dimension 1
    • Yes, it works now. Thanks for sharing. Based on your explanation, this also works: ' l = [] for i in encoded_images : for ii in i : l.append(ii) '
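
    As a side note on that flattening loop: a more idiomatic sketch is to concatenate the list of batch tensors in one call, assuming encoded_images is the list of 600 tensors of shape [100, 32] built in the question (all_codes is a hypothetical name):

    # Concatenate 600 tensors of shape [100, 32] into a single [60000, 32] tensor.
    all_codes = torch.cat(encoded_images, dim=0)
    print(all_codes.shape)  # torch.Size([60000, 32])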