什么是元学习

元学习是目前人工智能领域一个令人振奋的研究领域。大量的研究结果表明元学习已经在人工智能领域上取得了重大突破。在正式介绍元学习之前,先来看看传统的人工智能模型是如何工作的。

近几年来,深度学习迅速发展,出现了生成式对抗网络(GAN)和胶囊网络(capsule network)等优秀的算法。但深度神经网络存在的一个问题是,我们需要有一个很大的训练集来训练我们的模型,当我们只有很少的数据集时,它很大程度上会失败。再比如说,我们训练了一个深度学习模型来执行任务 A,现在,当我们有一个新的任务 B,并且与 A 密切相关,但是我们不能使用相同的模型。我们需要为任务 B 从头开始训练模型,因此对于每个任务,我们都需要从头开始训练模型,忽略了任务间的相关性。

事实上,人脑中的学习机制具备一种能力。在面对不同的任务时,人脑的学习机制并不相同。即使面对一个新的任务,人们往往也可以很快找到其学习方式。这种可以动态调整学习方式的能力,称为元学习。元学习产生了一个多功能的人工智能模型,它可以学习执行各种任务,而不需要从头开始训练它们。对于一个新的相关任务,它可以利用从以前的任务中获得的学习能力而不必从头开始训练它们。

MAML 算法

深度学习模型通过梯度的反向传播进行学习。然而,基于梯度的优化既不能应对少量的训练样本,也不能在少量的优化步骤内收敛。那么有没有一种方法可以调整优化算法,让模型在少量样本的情况下就能很好的学习?这就是基于优化方法(optimization-based)的元学习算法。

基于优化方法的元学习中,MAML可以说是当中最有名的了,它是一种相当通用的优化算法,可以与基于梯度更新的算法兼容。

其主要思想是:训练模型的初始化参数(initial parameter),使模型能在来自新任务的少量数据上对参数执行数次(1次或多次)的梯度更新后能得到最佳的表现。

MAML 算法的流程

假设所有的任务都来 自于一个任务空间,其分布为 \(p(\mathcal{T})\),我们可以在这个任务空间的所有任务上学习一种通用的表示,这种表示可以经过梯度下降方法在一个特定的单任务上进行精调(fine-tune)。假设一个模型为 \(f(\theta)\),如果我们让这个模型适应到一个新任务 \(\mathcal{T}(m)\)上, 通过一步或多步的梯度下降更新,学习到的任务适配参数为 \[ \theta_{m}^{\prime}=\theta-\alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_{m}}\left(f_{\theta}\right) \] 其中 \(\alpha\) 为学习率。这里 \(\theta_{m}^{\prime}\) 可以理解为关于 \(\theta\) 的函数,而不是真正的参数更新。

MAML的目标是学习一个参数 \(\theta\) 使得其经过一个梯度迭代就可以在新任务上达到最好的性能。 \[ \min _{\theta} \sum_{\mathcal{T}_{m} \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_{m}}\left(f\left(\theta_{m}^{\prime}\right)\right)=\sum_{\mathcal{T}_{m} \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_{m}}\left(f\left({\theta}-\alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_{m}}\left(f_{\theta}\right)\right)\right) \] 在所有任务上的元优化(Meta-Optimization)也采用梯度下降来进行优化,即

\[ \theta \leftarrow \theta-\beta \nabla_{\theta} \sum_{m=1}^{M} \mathcal{L}_{\mathcal{T}_{m}}\left(f_{\theta_{m}^{\prime}}\right) \] 其中 \(\beta\) 为元学习率,这里为一个真正的参数更新步骤。这里需要计算关于 \(\theta\) 的二阶梯度,但用一级近似通常也可以达到比较好的性能。

算法描述如下:

算法步骤:

  1. 随机初始化参数 \(\theta\)

  2. 从任务分布空间 \(p(\mathcal{T})\) 选取若干个任务

    在监督学习任务中,任务定义为 \[ \mathscr{T}_{i} \triangleq\left\{p_{i}(\mathbf{x}), p_{i}(\mathbf{y} \mid \mathbf{x}), \mathscr{L}_{i}\right\} \] 其中\[p_{i}(\mathbf{x}), p_{i}(\mathbf{y} \mid \mathbf{x})\]对应了真实的数据生成分布(通常无法得到,但是能够在训练的过程中反映出来),\(\mathscr{L}_{i}\) 表示的是损失函数。不同任务的区别可能是这三个中的某个或者多个不同。

  3. 更新任务 \[\mathcal{T}_{i}\] 的参数得到 \[\theta_{i}^{\prime}\]

  4. 根据内循环得到的 \(\theta_{i}^{\prime}\) 在其任务\(\mathcal{T}_i\) 上对应的测试集上计算损失(loss),并对初始化参数 \(\theta\) 进行梯度更新

算法过程如下图所示:

\(\theta\) 是初始化参数,也就是元学习的目标,通过不同任务的更新方向去更新 \(\theta\),使得 \(\theta\) 达到一个敏感(sensitive)位置,也就是说任务在这个位置只需要做一步或者几步的梯度下降就会引起损失函数的较大变化,迅速达到该任务的“最优”位置。

MAML 的简单例子

为了理解 MAML 算法是如何寻找一个更好且鲁棒的初始化参数 \(\theta\) 去泛化这些任务的。我们从头实现一个 MAML 算法。

为了方便理解,这里我们考虑一个二分类任务,随机生成一些训练数据,并用一个单层的神经网络去做训练。

生成训练数据

1
2
3
4
5
6
import numpy as np

def sample_points(k):
x = np.random.rand(k, 50)
y = np.random.choice([0, 1], size=k, p = [.5, .5]).reshape([-1, 1])
return x, y

测试一下函数

1
2
3
x, y = sample_points(10)
print(x[0])
print(y[0])
1
2
3
4
5
6
7
8
9
10
[0.42100647 0.69403514 0.67855943 0.57774404 0.15666203 0.91099776
0.01246907 0.08118128 0.07718763 0.57900541 0.35844089 0.03127173
0.5760209 0.71506323 0.94122599 0.95702823 0.35303246 0.35338326
0.35029937 0.62252008 0.66232719 0.11546544 0.69487415 0.36424861
0.46909378 0.12574337 0.82647021 0.40940772 0.99659183 0.27186814
0.56939123 0.02092227 0.01067603 0.9077246 0.85177498 0.5933496
0.48722737 0.66599244 0.01389198 0.9820469 0.11796959 0.94809311
0.8243121 0.27004836 0.30353558 0.87978769 0.59162869 0.07974159
0.27372265 0.41991061]
[0]

单层神经网络

为了简单以及方便理解,我们使用只有单层的神经网络来预测输出。如下

1
2
a = np.matmul(X, theta)
YHat = sigmoid(a)

所以,我们使用MAML来寻找这个最优的参数值theta,这个参数值是可以在各个任务中通用的。这样对于一个新的任务,我们只进行几次梯度下降,在较少的时间内从几个数据点中学习。

MAML 实现

现在,我们定义一个名为MAML的类,在这个类中我们实现了MAML算法。在 init 方法中,我们将初始化所有必要的变量。然后我们定义我们的 sigmoid 激活函数。接下来我们定义我们的训练函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class MAML(object):
def __init__(self):

# 初始化任务数
self.num_tasks = 10

# 初始化每个任务的样本数
self.num_samples = 10

# 算法迭代次数
self.epochs = 10000

# 内循环的学习率
self.alpha = 0.0001

# 外循环的学习率
self.beta = 0.0001

# 随机初始化初始参数
self.theta = np.random.normal(size=50).reshape(50, 1)

def sigmoid(self, a):
return 1.0 / (1 + np.exp(-a) )

def train(self):
for e in range(self.epochs):
self.theta_ = []

for i in range(self.num_tasks):
XTrain, YTrain = sample_points(self.num_samples)

a = np.matmul(XTrain, self.theta)

YHat = self.sigmoid(a)
# 交叉熵损失函数
loss = ((np.matmul(-YTrain.T, np.log(YHat)) - np.matmul((1-YTrain.T), np.log(1-YHat))) / self.num_samples)[0][0]
# 梯度下降
gradient = np.matmul(XTrain.T, (YHat - YTrain)) / self.num_samples
self.theta_.append(self.theta - self.alpha*gradient)
# 初始化 meta 梯度
meta_gradient = np.zeros(self.theta.shape)

for i in range(self.num_tasks):
# 任务的测试集
XTest, YTest = sample_points(10)
a = np.matmul(XTest, self.theta_[i])
YPred = self.sigmoid(a)

meta_gradient += np.matmul(XTest.T, (YPred - YTest)) / self.num_samples
# 更新外循环的原参数 \theta
self.theta = self.theta-self.beta*meta_gradient/self.num_tasks

if e%1000==0:
print("Epoch {}: Loss {}\n".format(e,loss))
print('Updated Model Parameter Theta\n')
print('Sampling Next Batch of Tasks \n')
print('---------------------------------\n')

测试

1
2
model = MAML()
model.train()

输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
Epoch 0: Loss 1.144232040593275

Updated Model Parameter Theta

Sampling Next Batch of Tasks

---------------------------------

Epoch 1000: Loss 0.6450955930326157

Updated Model Parameter Theta

Sampling Next Batch of Tasks

---------------------------------

Epoch 2000: Loss 0.646503969137526

Updated Model Parameter Theta

Sampling Next Batch of Tasks

---------------------------------

Epoch 3000: Loss 1.3089942317460603

Updated Model Parameter Theta

Sampling Next Batch of Tasks

---------------------------------

Epoch 4000: Loss 0.551582623398733

Updated Model Parameter Theta

Sampling Next Batch of Tasks

---------------------------------

...

MAML 算法的缺点

How to train your MAML 指出了 maml 存在的 5 个问题并给出了改进建议,以下是 maml 存在的 5 个问题

  1. Training Instabilit
  2. Second Order Derivative Cost
  3. Absence of Batch Normalization Statistic Accumulation
  4. Shared (across step) Batch Normalization Bias
  5. Shared Inner Loop (across step and across parameter) Learning Rate