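【Question title】: Extracting the encoded representations of an image in PyTorch?
【Posted】: 2020-06-03 03:41:03
【Question】:

I have a template autoencoder neural network model built in PyTorch, and I am using it on the Omniglot dataset. I want to extract the encoded representations of the images, but I am not sure how to do it.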

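import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot

# Load data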
mean = 0.5
std = 0.5
batch_size = 128
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,))
])
dataset = Omniglot('.', download=True, transform=img_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define Autoencoder
class Autoencoder(nn.Module):

    def __init__(self, n=64):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(105*105, 256, bias=True),
            nn.ReLU(True),
            nn.Linear(256, 64, bias=True),
            nn.ReLU(True),
            nn.Linear(64, n, bias=True),
            nn.ReLU(True)
        )
        self.decoder = nn.Sequential(
            nn.Linear(n, 64, bias=True),
            nn.ReLU(True),
            nn.Linear(64, 256, bias=True),
            nn.ReLU(True),
            nn.Linear(256, 105*105, bias=True),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

def train(num_epochs, dataloader, model, criterion, optimizer):
    for epoch in range(num_epochs):
        for data in dataloader:
            img, label = data
            img = img.view(img.size(0), -1)
            img = img.cuda()  # Variable is deprecated since PyTorch 0.4; plain tensors work
            output = model(img)
            loss = criterion(output, img)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model

# Train model
num_epochs = 25
learning_rate = 1e-3
model = Autoencoder().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
trained_model = train(num_epochs, dataloader, model, criterion, optimizer)

【Comments】:

    Tags: neural-network pytorch autoencoder


    【Solution 1】:

    You can simply return the encoded output from the forward function as well, like this:

    class Autoencoder(nn.Module):
    ...
        def forward(self, x):
            x = self.encoder(x)
            encoded_x = x
            x = self.decoder(x)
            return x, encoded_x
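
For reference, here is a minimal self-contained sketch of that change, using the same layer sizes as the question. The random batch is only a stand-in for flattened 105x105 Omniglot images, and everything runs on the CPU (both are assumptions made for illustration).

```python
import torch
import torch.nn as nn


class Autoencoder(nn.Module):
    """Same layer layout as in the question; forward returns two tensors."""

    def __init__(self, n=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(105 * 105, 256), nn.ReLU(True),
            nn.Linear(256, 64), nn.ReLU(True),
            nn.Linear(64, n), nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(n, 64), nn.ReLU(True),
            nn.Linear(64, 256), nn.ReLU(True),
            nn.Linear(256, 105 * 105), nn.Tanh(),
        )

    def forward(self, x):
        encoded_x = self.encoder(x)                 # bottleneck representation
        return self.decoder(encoded_x), encoded_x   # (reconstruction, code)


model = Autoencoder()
batch = torch.randn(4, 105 * 105)   # random stand-in for 4 flattened images
reconstruction, encoded = model(batch)
print(reconstruction.shape)  # torch.Size([4, 11025])
print(encoded.shape)         # torch.Size([4, 64])
```

Note that every caller of `model(...)` now receives a tuple, so any existing code that expects a single tensor has to be updated.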
    

    Then modify the training function slightly so that it unpacks both return values:

    output, encoded_output = model(img)
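
A sketch of what one training step might look like with the two-value forward. `TinyAE` and `train_step` are hypothetical names introduced here, the model is a deliberately small stand-in rather than the question's network, and the batch is random data.

```python
import torch
import torch.nn as nn


class TinyAE(nn.Module):
    # Hypothetical small stand-in; forward returns (reconstruction, code).
    def __init__(self, d_in=16, d_code=4):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(d_in, d_code), nn.ReLU())
        self.decoder = nn.Sequential(nn.Linear(d_code, d_in), nn.Tanh())

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code


def train_step(model, img, criterion, optimizer):
    output, encoded_output = model(img)   # unpack both return values
    loss = criterion(output, img)         # the loss uses only the reconstruction
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # detach() so the returned code does not keep the autograd graph alive
    return loss.item(), encoded_output.detach()


torch.manual_seed(0)
model = TinyAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
img = torch.rand(8, 16) * 2 - 1           # fake batch scaled to [-1, 1]
loss_value, codes = train_step(model, img, nn.MSELoss(), optimizer)
print(codes.shape)  # torch.Size([8, 4])
```

The codes captured during training come from a model whose weights are still changing, so for final features it is usually better to re-encode the data once training has finished.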
    

    Or you can call the encoder directly:

    encoded_output = model.encoder(img)
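
If the goal is to collect encodings for a whole dataset after training, one possible approach is to loop over a non-shuffled loader with gradients disabled. In this sketch `encode_dataset` and `TinyAE` are hypothetical names, and a `TensorDataset` of random images stands in for Omniglot.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class TinyAE(nn.Module):
    # Hypothetical small stand-in exposing an `encoder` attribute.
    def __init__(self, d_in=16, d_code=3):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(d_in, d_code), nn.ReLU())
        self.decoder = nn.Sequential(nn.Linear(d_code, d_in), nn.Tanh())

    def forward(self, x):
        return self.decoder(self.encoder(x))


def encode_dataset(model, loader, device="cpu"):
    """Run only the encoder over every batch and stack the results."""
    model.eval()
    chunks = []
    with torch.no_grad():                           # no autograd graph needed
        for img, _ in loader:
            img = img.view(img.size(0), -1).to(device)   # flatten each image
            chunks.append(model.encoder(img).cpu())
    return torch.cat(chunks)


torch.manual_seed(0)
images = torch.rand(10, 1, 4, 4)                    # fake 4x4 grayscale images
labels = torch.zeros(10, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=False)

model = TinyAE()
codes = encode_dataset(model, loader)
print(codes.shape)  # torch.Size([10, 3])
```

With `shuffle=False`, row `i` of `codes` corresponds to item `i` of the dataset, which makes it straightforward to pair each encoding with its label later.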
    

    【Comments】:
