Paper : Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
Code : official
摘要
作者根据元学习(meta learning)的表达式提出了MAML算法用来进行元知识的梯度下降,使用一阶近似的方法来避免计算损失函数的二阶导,并在小样本学习任务(few-shot learning)上取得了SOTA的成绩。作者强调MAML算法具有模型无关性,可以适用于任何基于梯度下降优化的模型上。并给出了MAML与监督学习的结合和MAML与强化学习结合的算法,强调了算法的通用性。
Meta-Learning
在此blog中,Meta-Learning 采用的是 Meta-Learning in Neural Networks: A Survey 一文中的形式化定义方式,之后的话博主会将这篇论文的blog编出来。
元学习的直观理解是 “Learning to learn”,也就是说通过多个任务的表现来改善学习算法本身。考虑对比常规的机器学习,常规监督学习的形式化表示如下:
给定训练集 D={(x1,y1),...(xN,yN)},希望找到一个预测模型 y=fθ(x) 在训练集上的最优参数解 θ∗,即通过下式求解
θ∗=argθmin(D;θ,ω)
这里使用 ω 表示该解对某些因素的依赖性,例如针对 θ 的优化器选择或针对 f 的模型的选择,常见的 ω 包括优化器的初始化(SGD的步长等),f 模型的参数初始化方法,正则化强度等等。ω 被称为是元知识(meta-knowledge),对于常规的机器学习任务来说,元知识是被人为设定的。元学习就是对元知识进行学习和优化。
学习元知识的过程形式化的表示为如下优化问题的求解过程
ωminET∼p(T)L(D,ω)
其中 T={D,L} 表示一个常规机器学习任务,假定多个常规任务都是从某个任务分布 p(T) 中采样出来的。我们希望学到的是跨任务的元知识 ω ,这些知识可以泛化到一个之前没有遇到过的数据集上,有助于模型在小样本数据集上进行学习。形式化的表述meta-training过程:
给定一个包含 M 个学习任务的数据集 D={(Dtrain(i),Dval(i))},求解问题可以表示为
ω∗=argωmaxlogp(ω∣D)
在meta-testing的过程中,首先需要根据元知识进行学习,然后再进行模型的评估,形式化表述为:
给定某训练阶段不可知的任务 j,测试模型的参数定义为
θ∗(j)=argθmaxlogp(θ∣ω∗,Dtrain(j))
从双层优化问题的角度来理解 meta-learning,meta-training可以形式化的表示为下式
ω∗=argωmini=1∑MLmeta(θ∗(i)(ω),ω,Dval(i))s.t. θ∗(i)(ω)=argθminLtask(θ,ω,Dtrain(i))
其中 Ltask 和 Lmeta 分别对应内层和外层的优化目标(损失函数)。
对于few-shot learning来说,一个常见的术语是 N-way K-shot classification,表示对于分类任务,类别总数有 N 个,每个类下面有 K 个样本。
MAML
MAML算法的前提是,存在对于模型来说存在某些参数初始化,比其他的初始化方法具有更好的迁移性,更适合做迁移学习。对于MAML算法来说,元知识 ω 表示模型的初始化参数,想要解决的问题是小样本学习的问题。小样本集意味着不能在复杂的模型上进行多轮训练,不然会产生overfit问题。MAML通过元学习的方法学到一种对新任务损失函数敏感的初始化方法,使得模型在初始化后经过较少的epoch就可以finetune到一个比较良好的表现上。
我们的目标是得到一个初始化模型参数可以经过较少的epoch获得一个良好的表现,因此,为了简化起见,假定 θ∗(i) 表示进行了一步梯度下降的结果,即
θ∗(i)=ω−ε▽ωLTi(fω)
而外层的优化目标为
ωminTi∼p(T)∑LTi(fθ∗(i))
使用SGD算法优化元知识 ω ,即
ω←ω−ϵ▽ωTi∼p(T)∑LTi(fθ∗(i))
考虑将 ω 和 θ∗(i) 都表示为向量,有
▽ωL(fθ∗)=[∂ω1∂L(fθ∗),...,∂ωK∂L(fθ∗)]T∂ωi∂L(fθ∗)=j∑∂θj∗∂L(fθ∗)∂ωi∂θj∗
根据单步SGD的前提,有
θj∗=ωj−ε∂ωj∂L(fω)
上述求导过程涉及到对梯度函数求梯度,结果存在二阶导,如下所示
∂ωi∂θj∗={1−ε∂ωj∂ωi∂2L(fω)−ε∂ωj∂ωi∂2L(fω)i=ji=j
对表达式进行一阶近似,假定 ε→0+ ,有
∂ωi∂θj∗={10i=ji=j
因此,代入结果有
▽ωL(fθ∗)≈▽θ∗L(fθ∗)
一阶近似的MAML元知识更新式子表示为
ω←ω−ϵTi∼p(T)∑▽θ∗(i)LTi(fθ∗(i))θ∗(i)=ω−ε▽ωLTi(fω)
MAML在不同任务上的应用
Few-Shot Supervised Learning

Reinforcement Learning
RL任务定义为
Ti=(qi(x1),qi(xt+1∣xt,at),LTi,Ri)
其中 q 表示初始状态分布和状态转移分布,损失函数表示为
LTi(fω)=−Exh,ah∼fω,qTi[h=1∑HRi(xh,ah)]
对于RL任务,通常使用Policy Gradient 方法进行梯度估计。

实验
作者通过实验观察到,一阶近似的性能与使用二阶导数获得的性能几乎相同,这表明MAML的大部分改进都来自目标在更新后参数值处的一阶梯度,而不是来自更新后参数值的二阶梯度。 过去的工作已经观察到ReLU神经网络在局部几乎是线性的,这表明在大多数情况下二阶导数可能接近于零,部分解释了一阶近似的良好性能。

MAML与Transfer Learning

总结
作者给出了一种基于梯度下降来学习模型初始化参数的元学习方法,方法简单,不会为元学习引入任何学习的参数。它可以与任何适合基于梯度训练的模型表示以及任何可区分的目标(包括分类,回归和强化学习)相结合。MAML的训练过程中只考虑了内部模型参数的一步更新,但是在测试时可以进行充分的finetune。这项工作是迈向一种简单通用的元学习技术的一步,该技术可应用于任何问题和模型。