Paper : Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
Code : official

摘要

作者根据元学习(meta learning)的表达式提出了MAML算法用来进行元知识的梯度下降,使用一阶近似的方法来避免计算损失函数的二阶导,并在小样本学习任务(few-shot learning)上取得了SOTA的成绩。作者强调MAML算法具有模型无关性,可以适用于任何基于梯度下降优化的模型上。并给出了MAML与监督学习的结合和MAML与强化学习结合的算法,强调了算法的通用性。

Meta-Learning

在此blog中,Meta-Learning 采用的是 Meta-Learning in Neural Networks: A Survey 一文中的形式化定义方式,之后的话博主会将这篇论文的blog编出来。

元学习的直观理解是 “Learning to learn”,也就是说通过多个任务的表现来改善学习算法本身。考虑对比常规的机器学习,常规监督学习的形式化表示如下:

给定训练集 D={(x1,y1),...(xN,yN)}\mathcal D = \{(x_1,y_1),...(x_N,y_N)\},希望找到一个预测模型 y^=fθ(x)\widehat y = f_{\theta}(x) 在训练集上的最优参数解 θ\theta^*,即通过下式求解

θ=argminθ(D;θ,ω) \theta^* = \arg\min_{\theta} \mathcal(\mathcal D;\theta,\omega)

这里使用 ω\omega 表示该解对某些因素的依赖性,例如针对 θ\theta 的优化器选择或针对 ff 的模型的选择,常见的 ω\omega 包括优化器的初始化(SGD的步长等),ff 模型的参数初始化方法,正则化强度等等。ω\omega 被称为是元知识(meta-knowledge),对于常规的机器学习任务来说,元知识是被人为设定的。元学习就是对元知识进行学习和优化。

学习元知识的过程形式化的表示为如下优化问题的求解过程

minωETp(T)L(D,ω) \min_\omega \mathbb{E}_{\mathcal{T}\sim p(\mathcal T)}\mathcal L(\mathcal D,\omega)

其中 T={D,L}\mathcal T = \{\mathcal D,\mathcal L\} 表示一个常规机器学习任务,假定多个常规任务都是从某个任务分布 p(T)p(\mathcal T) 中采样出来的。我们希望学到的是跨任务的元知识 ω\omega ,这些知识可以泛化到一个之前没有遇到过的数据集上,有助于模型在小样本数据集上进行学习。形式化的表述meta-training过程:

给定一个包含 M 个学习任务的数据集 D={(Dtrain(i),Dval(i))}\mathbb D = \{(\mathcal D^{(i)}_\text{train},\mathcal D^{(i)}_\text{val})\},求解问题可以表示为

ω=argmaxωlogp(ωD) \omega^* = \arg \max_{\omega} \log p(\omega|\mathbb D)

在meta-testing的过程中,首先需要根据元知识进行学习,然后再进行模型的评估,形式化表述为:

给定某训练阶段不可知的任务 j\mathcal j,测试模型的参数定义为

θ(j)=argmaxθlogp(θω,Dtrain(j)) {\theta^{*}}^{(j)} = \arg \max_{\theta} \log p(\theta|\omega^*,\mathcal D^{(j)}_\text{train})

从双层优化问题的角度来理解 meta-learning,meta-training可以形式化的表示为下式

ω=argminωi=1MLmeta(θ(i)(ω),ω,Dval(i))s.t. θ(i)(ω)=argminθLtask(θ,ω,Dtrain(i)) \\\omega^* = \arg\min_\omega \sum_{i=1}^M \mathcal L^\text{meta}({\theta^*}^{(i)}(\omega),\omega,\mathcal D_\text{val}^{(i)}) \\\text{s.t. }{\theta^*}^{(i)}(\omega) = \arg\min_{\theta} \mathcal L^\text{task}(\theta,\omega,\mathcal D_\text{train}^{(i)})

其中 Ltask\mathcal L^\text{task}Lmeta\mathcal L^\text{meta} 分别对应内层和外层的优化目标(损失函数)。

对于few-shot learning来说,一个常见的术语是 N-way K-shot classification,表示对于分类任务,类别总数有 N 个,每个类下面有 K 个样本。

MAML

MAML算法的前提是,存在对于模型来说存在某些参数初始化,比其他的初始化方法具有更好的迁移性,更适合做迁移学习。对于MAML算法来说,元知识 ω\omega 表示模型的初始化参数,想要解决的问题是小样本学习的问题。小样本集意味着不能在复杂的模型上进行多轮训练,不然会产生overfit问题。MAML通过元学习的方法学到一种对新任务损失函数敏感的初始化方法,使得模型在初始化后经过较少的epoch就可以finetune到一个比较良好的表现上。

我们的目标是得到一个初始化模型参数可以经过较少的epoch获得一个良好的表现,因此,为了简化起见,假定 θ(i){\theta^*}^{(i)} 表示进行了一步梯度下降的结果,即

θ(i)=ωεωLTi(fω) {\theta^*}^{(i)} = \omega - \varepsilon \triangledown_{\omega} \mathcal L_{\mathcal{T}_i}(f_\omega)

而外层的优化目标为

minωTip(T)LTi(fθ(i)) \min_\omega \sum_{\mathcal T_i\sim p(\mathcal T)}\mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}})

使用SGD算法优化元知识 ω\omega ,即

ωωϵωTip(T)LTi(fθ(i)) \omega\leftarrow \omega -\epsilon \triangledown_{\omega} \sum_{\mathcal T_i\sim p(\mathcal T)}\mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}})

考虑将 ω\omegaθ(i){\theta^*}^{(i)} 都表示为向量,有

ωL(fθ)=[L(fθ)ω1,...,L(fθ)ωK]T  L(fθ)ωi=jL(fθ)θjθjωi \\ \triangledown_{\omega} \mathcal L(f_{\theta^*}) = [\frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_1},...,\frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_K}]^\text T \\\; \\ \frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_i} = \sum_j \frac{\partial \mathcal L(f_{\theta^*})}{\partial \theta^*_j}\frac{\partial \theta^*_j}{\partial \omega_i}

根据单步SGD的前提,有

θj=ωjεL(fω)ωj \theta^*_j = \omega_j-\varepsilon \frac{\partial \mathcal L(f_\omega)}{\partial \omega_j}

上述求导过程涉及到对梯度函数求梯度,结果存在二阶导,如下所示

θjωi={1ε2L(fω)ωjωii=jε2L(fω)ωjωiij \frac{\partial \theta^*_j}{\partial \omega_i} = \left\{\begin{matrix} 1-\varepsilon \frac{\partial^2 \mathcal L(f_\omega)}{\partial \omega_j \partial \omega_i} & i= j\\ -\varepsilon \frac{\partial^2 \mathcal L(f_\omega)}{\partial \omega_j \partial \omega_i} & i\not = j \end{matrix}\right.

对表达式进行一阶近似,假定 ε0+\varepsilon \rightarrow 0^+ ,有

θjωi={1i=j0ij \frac{\partial \theta^*_j}{\partial \omega_i} = \left\{\begin{matrix} 1 & i= j\\ 0 & i\not = j \end{matrix}\right.

因此,代入结果有

ωL(fθ)θL(fθ) \\ \triangledown_{\omega} \mathcal L(f_{\theta^*}) \approx \triangledown_{\theta^*} \mathcal L(f_{\theta^*})

一阶近似的MAML元知识更新式子表示为

ωωϵTip(T)θ(i)LTi(fθ(i))θ(i)=ωεωLTi(fω) \omega\leftarrow \omega -\epsilon \sum_{\mathcal T_i\sim p(\mathcal T)} \triangledown_{{\theta^*}^{(i)}} \mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}}) \\ {\theta^*}^{(i)} = \omega - \varepsilon \triangledown_{\omega} \mathcal L_{\mathcal{T}_i}(f_\omega)

MAML在不同任务上的应用

Few-Shot Supervised Learning

MAML

Reinforcement Learning

RL任务定义为

Ti=(qi(x1),qi(xt+1xt,at),LTi,Ri) \mathcal T_i = (q_i(x_1),q_i(x_{t+1}|x_t,a_t),\mathcal L_{\mathcal T_i},R_i)

其中 qq 表示初始状态分布和状态转移分布,损失函数表示为

LTi(fω)=Exh,ahfω,qTi[h=1HRi(xh,ah)] \mathcal L_{\mathcal T_i}(f_\omega) = -\mathbb E_{x_h,a_h\sim f_\omega,q_{\mathcal T_i}}[\sum_{h=1}^HR_i(x_h,a_h)]

对于RL任务,通常使用Policy Gradient 方法进行梯度估计。

MAML

实验

作者通过实验观察到,一阶近似的性能与使用二阶导数获得的性能几乎相同,这表明MAML的大部分改进都来自目标在更新后参数值处的一阶梯度,而不是来自更新后参数值的二阶梯度。 过去的工作已经观察到ReLU神经网络在局部几乎是线性的,这表明在大多数情况下二阶导数可能接近于零,部分解释了一阶近似的良好性能。

MAML

MAML与Transfer Learning

MAML

总结

作者给出了一种基于梯度下降来学习模型初始化参数的元学习方法,方法简单,不会为元学习引入任何学习的参数。它可以与任何适合基于梯度训练的模型表示以及任何可区分的目标(包括分类,回归和强化学习)相结合。MAML的训练过程中只考虑了内部模型参数的一步更新,但是在测试时可以进行充分的finetune。这项工作是迈向一种简单通用的元学习技术的一步,该技术可应用于任何问题和模型。

相关文章:

  • 2021-07-20
  • 2021-11-09
  • 2022-01-14
  • 2021-04-05
  • 2022-12-23
  • 2021-05-15
  • 2021-05-31
  • 2021-07-23
猜你喜欢
  • 2022-12-23
  • 2021-05-05
  • 2021-04-10
  • 2021-08-13
  • 2022-01-10
  • 2021-04-28
相关资源
相似解决方案