layout: post
title: 元学习入门
subtitle: 元学习入门
date: 2021-11-15
author: NSX
header-img: img/post-bg-ios9-web.jpg
catalog: true
tags:

  • Meta-Learning

元学习入门

转载自https://zhuanlan.zhihu.com/p/136975128

以下是本文的主要框架:

  1. Introduction
  2. Meta Learning实施——以MAML为例
  3. Reptile
  4. What’s more

1. Introduction

通常在机器学习里,我们会使用某个场景的大量数据来训练模型;然而当场景发生改变,模型就需要重新训练。但是对于人类而言,一个小朋友成长过程中会见过许多物体的照片,某一天,当Ta(第一次)仅仅看了几张狗的照片,就可以很好地对狗和其他物体进行区分。

元学习Meta Learning,含义为学会学习,即learn to learn,就是带着这种对人类这种“学习能力”的期望诞生的。Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,如:

  • 让Alphago迅速学会下象棋
  • 让一个猫咪图片分类器,迅速具有分类其他物体的能力

需要注意的是,虽然同样有“预训练”的意思在里面,但是元学习的内核区别于迁移学习(Transfer Learning),关于他们的区别,我会在下文进行阐述。

接下来,我们通过对比机器学习和元学习这两个概念的要素来加深对元学习这个概念的理解。

2021-11-15-元学习入门 - 图1

在机器学习中,训练单位是一条数据,通过数据来对模型进行优化;数据可以分为训练集、测试集和验证集。在元学习中,训练单位分层级了,第一层训练单位是任务,也就是说,元学习中要准备许多任务来进行学习,第二层训练单位才是每个任务对应的数据。

二者的目的都是找一个Function,只是两个Function的功能不同,要做的事情不一样。机器学习中的Function直接作用于特征和标签,去寻找特征与标签之间的关联;而元学习中的Function是用于寻找新的f,新的f才会应用于具体的任务。有种不同阶导数的感觉。又有种老千层饼的感觉,你看到我在第二层,你把我想象成第一层,而其实我在第五层。。。

2. Meta Learning实施——以MAML为例

我们先对比机器学习的过程来进一步理解元学习。如下图所示,机器学习的一般过程如下:

  • 设计网络网络结构,如CNN、RNN等;
  • 选定某个分布来初始化参数;(以上其实决定了初始的f的长相,选择不同的网络结构或参数相当于定义了不同的f);
  • 喂训练数据,根据选定的Loss Function计算Loss;
  • 梯度下降,逐步更新 ;
  • 得到最终的f

2021-11-15-元学习入门 - 图2机器学习过程,引自李宏毅《深度学习》

其中,红色方框里的“配置”都是由人为设计的,我们又叫做“超参数“。Meta Learning中希望把这些配置,如网络结构,参数初始化,优化器等由机器自行设计(注:此处区别于AutoML,迁移学习(Transfer Learning)和终身学习(Life Long Learning) ),使网络有更强的学习能力和表现。

上文已经提到,【元学习中要准备许多任务来进行学习,而每个任务又有各自的训练集和测试集】。我们结合一个具体的任务,来介绍元学习和MAML的实施过程。

有一个图像数据集叫Omniglot:github.com/brendenlake/。Omniglot包含1623个不同的火星文字符,每个字符包含20个手写的case。这个任务是判断每个手写的case属于哪一个火星文字符。

如果我们要进行N-ways,K-shot(数据中包含N个字符类别,每个字符有K张图像)的一个图像分类任务。比如20-ways,1-shot分类的意思是说,要做一个20分类,但是每个分类下只有1张图像的任务。我们可以依据Omniglot构建很多N-ways,K-shot任务,这些任务将作为元学习的任务来源。构建的任务分为训练任务(Train Task),测试任务(Test Task)。特别地,每个任务包含自己的训练数据、测试数据,在元学习里,分别称为Support Set和Query Set。

MAML的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)。我们通过(许多)N-ways,K-shot的任务(训练任务)进行元学习的训练,使得模型学习到“先验知识”(初始化的参数)。这个“先验知识”在新的N-ways,K-shot任务上可以表现的更好。

接下来介绍MAML的算法流程:

2021-11-15-元学习入门 - 图3MAML算法流程

当然,在“预训练”阶段,也可以sample出1个batch的几个任务,那么在更新meta网络时,要使用sample出所有任务的梯度之和。
注意:在MAML中,meta网络与子任务的网络结构必须完全相同。

这里面有几个小问题:

  1. MAML的执行过程与model pretraining & transfer learning的区别是什么?
  2. 为何在meta网络赋值给具体训练任务(如任务m)后,要先更训练任务的参数,再计算梯度,更新meta网络?
  3. 在更新训练任务的网络时,只走了一步,然后更新meta网络。为什么是一步,可以是多步吗?

这三个问题是MAML中很核心的问题,大家可以先思考一下,我们将在后文进行解答。我们先看一下MAML的实现代码。

  1. ## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow
  2. #################################################
  3. # 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
  4. # support set:5 * 1
  5. # query set:5 * 15
  6. # 训练取1个batch的任务:batch size:4
  7. # 对训练任务进行训练时,更新5次:K = 5
  8. #################################################
  9. print(support_x) # (4, 5, 21168)
  10. print(query_x) # (4, 75, 21168)
  11. print(support_y) # (4, 5, 5)
  12. print(query_y) # (4, 75, 5)
  13. print(meta_batchsz) # 4
  14. print(K) # 5
  15. model = MAML()
  16. model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
  17. class MAML:
  18. def __init__(self):
  19. pass
  20. def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
  21. """
  22. :param support_xb: [4, 5, 84*84*3]
  23. :param support_yb: [4, 5, n-way]
  24. :param query_xb: [4, 75, 84*84*3]
  25. :param query_yb: [4, 75, n-way]
  26. :param K: 训练任务的网络更新步数
  27. :param meta_batchsz: 任务数,4
  28. """
  29. self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数
  30. training = True if mode is 'train' else False
  31. def meta_task(input):
  32. """
  33. :param support_x: [setsz, 84*84*3] (5, 21168)
  34. :param support_y: [setsz, n-way] (5, 5)
  35. :param query_x: [querysz, 84*84*3] (75, 21168)
  36. :param query_y: [querysz, n-way] (75, 5)
  37. :param training: training or not, for batch_norm
  38. :return:
  39. """
  40. support_x, support_y, query_x, query_y = input
  41. query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果
  42. ## 第0次对网络进行更新
  43. support_pred = self.forward(support_x, self.weights, training) # 前向计算support set
  44. support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
  45. support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
  46. tf.argmax(support_y, axis=1))
  47. grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度
  48. gvs = dict(zip(self.weights.keys(), grads))
  49. # 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads
  50. fast_weights = dict(zip(self.weights.keys(), \
  51. [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
  52. # 使用梯度更新后的参数对quert set进行前向计算
  53. query_pred = self.forward(query_x, fast_weights, training)
  54. query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
  55. query_preds.append(query_pred)
  56. query_losses.append(query_loss)
  57. # 第1到 K-1次对网络进行更新
  58. for _ in range(1, K):
  59. loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
  60. labels=support_y)
  61. grads = tf.gradients(loss, list(fast_weights.values()))
  62. gvs = dict(zip(fast_weights.keys(), grads))
  63. fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
  64. for key in fast_weights.keys()]))
  65. query_pred = self.forward(query_x, fast_weights, training)
  66. query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
  67. # 子网络更新K次,记录每一次queryset的结果
  68. query_preds.append(query_pred)
  69. query_losses.append(query_loss)
  70. for i in range(K):
  71. query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
  72. tf.argmax(query_y, axis=1)))
  73. result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
  74. return result
  75. # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
  76. out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
  77. result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
  78. dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
  79. support_pred_tasks, support_loss_tasks, support_acc_tasks, \
  80. query_preds_tasks, query_losses_tasks, query_accs_tasks = result
  81. if mode is 'train':
  82. self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
  83. self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
  84. for j in range(K)]
  85. self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
  86. self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
  87. for j in range(K)]
  88. # 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调
  89. optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
  90. gvs = optimizer.compute_gradients(self.query_losses[-1])
  91. # def ********

接下来回答一下上面的三个问题:

问题1:MAML的执行过程与model pretraining & transfer learning的区别是什么?

我们将meta learning与model pretraining的loss函数写出来。

2021-11-15-元学习入门 - 图4meta learning与model pretraining的loss函数

注意这两个loss函数的区别:

  • meta learning的L来源于训练任务上网络的参数更新过一次后(该网络更新过一次以后,网络的参数与meta网络的参数已经有一些区别),然后使用Query Set计算的loss;
  • model pretraining的L来源于同一个model的参数(只有一个),使用训练数据计算的loss和梯度对model进行更新;如果有多个训练任务,我们可以将这个参数在很多任务上进行预训练,训练的所有梯度都会直接更新到model的参数上。

看一下二者的更新过程简图:

2021-11-15-元学习入门 - 图5meta learning与model pretraining训练过程,引自李宏毅《深度学习》

  1. MAML是使用子任务的参数,第二次更新的gradient的方向来更新参数(所以左图,第一个蓝色箭头的方向与第二个绿色箭头的方向平行;左图第二个蓝色箭头的方向与第二个橘色箭头的方向平行)
  2. 而model pretraining是使用子任务第一步更新的gradient的方向来更新参数(子任务的梯度往哪个方向走,model的参数就往哪个方向走)。

从sense上直观理解:

  • model pretraining最小化当前的model(只有一个)在所有任务上的loss,所以model pretraining希望找到一个在所有任务(实际情况往往是大多数任务)上都表现较好的一个初始化参数,这个参数要在多数任务上当前表现较好。
  • meta learning最小化每一个子任务训练一步之后,第二次计算出的loss,用第二步的gradient更新meta网络,这代表了什么呢?子任务从【状态0】,到【状态1】,我们希望状态1的loss小,说明meta learning更care的是初始化参数未来的潜力。

一个关注当下,一个关注潜力。

  • 如下图所示,model pretraining找到的参数 2021-11-15-元学习入门 - 图6 ,在两个任务上当前的表现比较好(当下好,但训练之后不保证好);
  • 而MAML的参数 2021-11-15-元学习入门 - 图7在两个子任务当前的表现可能都不是很好,但是如果在两个子任务上继续训练下去,可能会达到各自任务的局部最优(潜力好)。

2021-11-15-元学习入门 - 图8引自李宏毅《深度学习》

这里有一个toy example可以表现MAML的执行过程与model pretraining & transfer learning的区别。

训练任务:给定N个函数,y = asinx + b(通过给a和b不同的取值可以得到很多sin函数),从每个函数中sample出K个点,用sample出的K个点来预估最初的函数,即求解a和b的值。

训练过程:用这N个训练任务sample出的数据点分别通过MAML与model pretraining训练网络,得到预训练的参数。

如下图,用橘黄色的sin函数作为测试任务,三角形的点是测试任务中sample出的样本点,在测试任务中,我们希望用sample出的样本点还原橘黄色的线。

2021-11-15-元学习入门 - 图9Toy example,引自李宏毅《深度学习》

  • model pretraining的结果,在测试任务上,在finetuning之前,绿色线是一条水平线,finetuning之后还原的线基本还是一条水平线。因为在预训练的时候,有很多sin函数,model pretraining希望找到一个在所有任务上都效果较好的初始化结果,但是许多sin函数波峰和波谷重叠起来,基本就是一条水平线。用这个初始化的结果取finetuning,得到的结果仍然是水平线。
  • MAML的初始化结果是绿色的线,和橘黄色的线有差异。但是随着finetuning的进行,结果与橘黄色的线更加接近。

问题2:为何在meta网络赋值给具体训练任务(如任务m)后,要先更训练任务的参数,再计算梯度,更新meta网络?

这个问题其实在问题1中已经进行了回答,更新一步之后,避免了meta learning陷入了和model pretraining一样的训练模式,更重要的是,可以使得meta模型更关注参数的“潜力”。

问题3:在更新训练任务的网络时,只走了一步,然后更新meta网络。为什么是一步,可以是多步吗?

李宏毅老师的课程中提到:

  • 只更新一次,速度比较快;因为meta learning中,子任务有很多,都更新很多次,训练时间比较久。
  • MAML希望得到的初始化参数在新的任务中finetuning的时候效果好。如果只更新一次,就可以在新任务上获取很好的表现。把这件事情当成目标,可以使得meta网络参数训练是很好(目标与需求一致)。
  • 当初始化参数应用到具体的任务中时,也可以finetuning很多次。
  • Few-shot learning往往数据较少。

那么MAML中的训练任务的网络可以更新多次后,再更新meta网络吗?

我觉得可以。直观上感觉,更新次数决定了子任务对于meta网络的影响程度,我觉得这个步数可以作为一个参数来调。

另外,即将介绍的下一个网络——Reptile,也是对训练任务网络进行多次更新的。

3. Reptile

Reptile与MAML有点像,我们先看一下Reptile的训练简图:

2021-11-15-元学习入门 - 图10Reptile训练过程,引自李宏毅《深度学习》

Reptile的训练过程如下:

2021-11-15-元学习入门 - 图11Reptile,每次sample出1个训练任务

2021-11-15-元学习入门 - 图12Reptile,每次sample出1个batch训练任务

在Reptile中:

  • 训练任务的网络可以更新多次
  • reptile不再像MAML一样计算梯度(因此带来了工程性能的提升),而是直接用一个参数 2021-11-15-元学习入门 - 图13 乘以meta网络与训练任务的网络参数的差来更新meta网络参数
  • 从效果上来看,Reptile效果与MAML基本持平

4. What’s more

元学习入门部分的文章基本就分享到这里了~

  • 从出发点上来看,元学习和model pretraining有点像,即,都是让网络具有一些先验知识。
  • 从训练过程的设计来看,元学习更关注模型的潜力,而model pretraining更注重模型当下在多数情况下的表现,效果孰好孰坏很难直接判定。这大概也就是仰望天空和脚踏实地的区别hahaha
  • 元学习除了可以初始化参数以外,还有一些设计可以帮助确定网络结构,如何更新参数等等这里有李宏毅老师的一个课程大家可以关注一下youtube.com/watch? 。

参考

一文入门元学习(Meta-Learning)(附代码)