【问题标题】:Importance weighted autoencoder doing worse than VAE重要性加权自编码器比 VAE 做得更差
【发布时间】:2023-07-17 15:57:02
【问题描述】:

我一直在 caltech silhouettes 数据集上实现 VAE 和 IWAE 模型,但我遇到了一个问题,即 VAE 的性能略微优于 IWAE(测试 LL ~120 的 VAE,~133 的 IWAE!)。根据here产生的理论和实验,我不认为应该是这种情况。

我希望有人能在我的实施方式中找到一些导致这种情况的问题。

我用来近似qp 的网络与上述论文附录中详述的网络相同。模型的计算部分如下:

data_k_vec = data.repeat_interleave(K,0) # Generate K samples (in my case K=50 is producing this behavior)

mu, log_std = model.encode(data_k_vec)
z = model.reparameterize(mu, log_std) # z = mu + torch.exp(log_std)*epsilon (epsilon ~ N(0,1))
decoded = model.decode(z) # this is the sigmoid output of the model

log_prior_z = torch.sum(-0.5 * z ** 2, 1)-.5*z.shape[1]*T.log(torch.tensor(2*np.pi))
log_q_z = compute_log_probability_gaussian(z, mu, log_std) # Definitions below
log_p_x = compute_log_probability_bernoulli(decoded,data_k_vec) 

if model_type == 'iwae':
    log_w_matrix = (log_prior_z + log_p_x  - log_q_z).view(-1, K)
elif model_type =='vae':
    log_w_matrix = (log_prior_z + log_p_x  - log_q_z).view(-1, 1)*1/K

log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]
ws_matrix = torch.exp(log_w_minus_max)
ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

loss = -torch.sum(ws_sum_per_datapoint) # value of loss that gets returned to training function. loss.backward() will get called on this value

这里是似然函数。我不得不对 bernoulli LL 大惊小怪,以免在训练期间得到 nan

def compute_log_probability_gaussian(obs, mu, logstd, axis=1):
    return torch.sum(-0.5 * ((obs-mu) / torch.exp(logstd)) ** 2 - logstd, axis)-.5*obs.shape[1]*T.log(torch.tensor(2*np.pi)) 

def compute_log_probability_bernoulli(theta, obs, axis=1): # Add 1e-18 to avoid nan appearances in training
    return torch.sum(obs*torch.log(theta+1e-18) + (1-obs)*torch.log(1-theta+1e-18), axis)

在此代码中,使用了一个“快捷方式”,即在 model_type=='iwae' 的情况下计算每行 K=50 个样本的逐行重要性权重,而在 model_type=='vae' 的情况下计算重要性权重正在计算每行中剩余的单个值,因此它最终计算的权重为 1。也许这就是问题所在?

任何和所有的帮助都是巨大的——我认为解决 nan 问题会让我永远摆脱困境,但现在我遇到了这个新问题。

编辑: 应该补充一点,培训计划与上面链接的论文中的相同。也就是说,对于i=0....7 的每一轮训练2**i epochs,学习率为1e-4 * 10**(-i/7)

【问题讨论】:

    标签: machine-learning pytorch autoencoder unsupervised-learning bayesian-networks


    【解决方案1】:

    K-样本重要性加权 ELBO 是

    $$ \textrm{IW-ELBO}(x,K) = \log \sum_{k=1}^K \frac{p(x \vert z_k) p(z_k)}{q(z_k;x )}$$

    对于 IWAE,有来自每个数据点 xK 样本,因此您希望通过摊销推理网络获得相同的潜在统计信息 mu_z, Sigma_z,但对每个 @ 采样多个 z K 次987654326@.

    因此计算data_k_vec = data.repeat_interleave(K,0) 的前向传递在计算上是浪费的,您应该为每个原始数据点计算一次前向传递,然后重复推理网络输出的统计信息进行采样:

    mu = torch.repeat_interleave(mu,K,0)
    log_std = torch.repeat_interleave(log_std,K,0)
    

    然后采样z_k。现在重复您的数据点data_k_vec = data.repeat_interleave(K,0),并使用生成的张量有效地评估每个重要性样本z_k 的条件p(x |z_k)

    请注意,在计算 IW-ELBO 时,您可能还需要使用 logsumexp 运算以实现数值稳定性。我不太清楚您帖子中的log_w_matrix 计算发生了什么,但这是我要做的:

    log_pz = ...
    log_qzCx = ....
    log_pxCz = ...
    
    log_iw = log_pxCz + log_pz - log_qzCx
    log_iw = log_iw.reshape(-1, K)
    iwelbo = torch.logsumexp(log_iw, dim=1) - np.log(K)
    

    编辑:实际上在考虑了一下并使用分数函数标识之后,您可以将 IWAE 梯度解释为标准单样本梯度的重要性加权估计,因此 OP 中用于计算重要性权重的方法是等价的(如果有点浪费的话),前提是您在标准化重要性权重周围放置一个stop_gradient 运算符,您称之为w_norm。所以我的主要问题是缺少这个stop_gradient 运算符。

    【讨论】:

      最近更新 更多