【发布时间】:2021-07-01 13:44:17
【问题描述】:
使用 PyTorch 训练自动编码器后,如何在某个隐藏级别提取输入特征的低维嵌入?
【问题讨论】:
标签: pytorch autoencoder
使用 PyTorch 训练自动编码器后,如何在某个隐藏级别提取输入特征的低维嵌入?
【问题讨论】:
标签: pytorch autoencoder
您可以定义您的模型,使其可选地返回在前向传递期间计算的中间 pytorch 变量。简单例子:
class Autoencoder(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 3)) #reduce the size
self.decoder = nn.Sequential(
nn.Linear(3, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, input_size),
nn.ReLU()) #reduce the size
def forward(self, x, return_encoding = False):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
if return_encoding:
return decoded,encoded
return decoded
【讨论】: