【发布时间】:2023-07-17 15:57:02
【问题描述】:
我一直在 caltech silhouettes 数据集上实现 VAE 和 IWAE 模型,但我遇到了一个问题,即 VAE 的性能略微优于 IWAE(测试 LL ~120 的 VAE,~133 的 IWAE!)。根据here产生的理论和实验,我不认为应该是这种情况。
我希望有人能在我的实施方式中找到一些导致这种情况的问题。
我用来近似q 和p 的网络与上述论文附录中详述的网络相同。模型的计算部分如下:
data_k_vec = data.repeat_interleave(K,0) # Generate K samples (in my case K=50 is producing this behavior)
mu, log_std = model.encode(data_k_vec)
z = model.reparameterize(mu, log_std) # z = mu + torch.exp(log_std)*epsilon (epsilon ~ N(0,1))
decoded = model.decode(z) # this is the sigmoid output of the model
log_prior_z = torch.sum(-0.5 * z ** 2, 1)-.5*z.shape[1]*T.log(torch.tensor(2*np.pi))
log_q_z = compute_log_probability_gaussian(z, mu, log_std) # Definitions below
log_p_x = compute_log_probability_bernoulli(decoded,data_k_vec)
if model_type == 'iwae':
log_w_matrix = (log_prior_z + log_p_x - log_q_z).view(-1, K)
elif model_type =='vae':
log_w_matrix = (log_prior_z + log_p_x - log_q_z).view(-1, 1)*1/K
log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]
ws_matrix = torch.exp(log_w_minus_max)
ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)
ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)
loss = -torch.sum(ws_sum_per_datapoint) # value of loss that gets returned to training function. loss.backward() will get called on this value
这里是似然函数。我不得不对 bernoulli LL 大惊小怪,以免在训练期间得到 nan
def compute_log_probability_gaussian(obs, mu, logstd, axis=1):
return torch.sum(-0.5 * ((obs-mu) / torch.exp(logstd)) ** 2 - logstd, axis)-.5*obs.shape[1]*T.log(torch.tensor(2*np.pi))
def compute_log_probability_bernoulli(theta, obs, axis=1): # Add 1e-18 to avoid nan appearances in training
return torch.sum(obs*torch.log(theta+1e-18) + (1-obs)*torch.log(1-theta+1e-18), axis)
在此代码中,使用了一个“快捷方式”,即在 model_type=='iwae' 的情况下计算每行 K=50 个样本的逐行重要性权重,而在 model_type=='vae' 的情况下计算重要性权重正在计算每行中剩余的单个值,因此它最终计算的权重为 1。也许这就是问题所在?
任何和所有的帮助都是巨大的——我认为解决 nan 问题会让我永远摆脱困境,但现在我遇到了这个新问题。
编辑:
应该补充一点,培训计划与上面链接的论文中的相同。也就是说,对于i=0....7 的每一轮训练2**i epochs,学习率为1e-4 * 10**(-i/7)
【问题讨论】:
标签: machine-learning pytorch autoencoder unsupervised-learning bayesian-networks