title: Knowledge Distillation from Scratch
subtitle: Knowledge Distillation from Scratch
date: 2021-07-27
author: NSX
catalog: true
tags:

  • Knowledge Distillation

Translated from https://keras.io/examples/vision/knowledge_distillation/#train-student-from-scratch-for-comparison


For more on distillation and model-inference acceleration, see the blog post 《预训练模型参数量越来越大?这里有你需要的BERT推理加速技术指南》.

Introduction to Knowledge Distillation

Knowledge distillation is a model-compression procedure in which a small (student) model is trained to match a large pre-trained (teacher) model. Knowledge is transferred from the teacher to the student by minimizing a loss function that aims to match the softened teacher logits as well as the ground-truth labels.

The logits are softened by applying a "temperature" scaling function in the softmax, which effectively smooths the probability distribution and reveals the inter-class relationships learned by the teacher.
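To make the effect of the temperature concrete, here is a minimal NumPy sketch (not part of the original tutorial; the logit values are made up): at T = 1 the distribution is nearly one-hot, while a larger T spreads probability mass across the other classes and exposes how similar the teacher considers them.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits scaled by a temperature factor."""
    scaled = np.asarray(logits, dtype="float64") / temperature
    exps = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    return exps / exps.sum()

logits = [8.0, 2.0, 0.5]                       # hypothetical teacher logits
print(softmax_with_temperature(logits, 1.0))   # sharp:  ~[0.997, 0.002, 0.001]
print(softmax_with_temperature(logits, 10.0))  # softer: ~[0.49, 0.27, 0.23]
```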

Setup

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
```

Construct the Distiller() class

The custom Distiller() class overrides the Model methods train_step, test_step, and compile(). In order to use the distiller, we need:

  • A trained teacher model
  • A student model to train
  • A student loss function on the difference between student predictions and ground truth
  • A distillation loss function, along with a temperature, on the difference between the soft student predictions and the soft teacher labels
  • An alpha factor to weight the student and distillation losses
  • An optimizer for the student, and (optional) metrics to evaluate performance

In the train_step method, we perform a forward pass of both the teacher and the student, compute the loss by weighting student_loss and distillation_loss with alpha and 1 - alpha respectively, and perform the backward pass. Note: only the student weights are updated, and therefore we only calculate the gradients for the student weights.

In the test_step method, we evaluate the student model on the provided dataset.
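Written out, the combined loss that train_step implements below is

$$
\mathcal{L} = \alpha \, \mathcal{L}_{\text{student}}(y, z_s) + (1 - \alpha)\, \mathcal{L}_{\text{distill}}\big(\mathrm{softmax}(z_t / T),\ \mathrm{softmax}(z_s / T)\big)
$$

where $z_s$ and $z_t$ are the student and teacher logits, $T$ is the temperature, and $y$ is the ground-truth label.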

```python
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
```

Create student and teacher models

Initially, we create a teacher model and a smaller student model. Both models are convolutional neural networks created with Sequential(), but they could be any Keras model.

```python
# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
```

Prepare the dataset

The dataset used for training the teacher and distilling it into the student is MNIST; the procedure would be equivalent for any other dataset, e.g. CIFAR-10, with a suitable choice of models. Both the student and the teacher are trained on the training set and evaluated on the test set.

```python
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
```

Train the teacher

In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start by training the teacher model on the training set in the usual way.

```python
# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
```

```
Epoch 1/5
1875/1875 [==============================] - 248s 132ms/step - loss: 0.2438 - sparse_categorical_accuracy: 0.9220
Epoch 2/5
1875/1875 [==============================] - 263s 140ms/step - loss: 0.0881 - sparse_categorical_accuracy: 0.9738
Epoch 3/5
1875/1875 [==============================] - 245s 131ms/step - loss: 0.0650 - sparse_categorical_accuracy: 0.9811
Epoch 5/5
363/1875 [====>.........................] - ETA: 3:18 - loss: 0.0555 - sparse_categorical_accuracy: 0.9839
```

Distill teacher to student

Having trained the teacher model, we only need to initialize a Distiller(student, teacher) instance, compile() it with the desired losses, hyperparameters, and optimizer, and distill the teacher into the student.

```python
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
```

```
Epoch 1/3
1875/1875 [==============================] - 242s 129ms/step - sparse_categorical_accuracy: 0.9761 - student_loss: 0.1526 - distillation_loss: 0.0226
Epoch 2/3
1875/1875 [==============================] - 281s 150ms/step - sparse_categorical_accuracy: 0.9863 - student_loss: 0.1384 - distillation_loss: 0.0185
Epoch 3/3
399/1875 [=====>........................] - ETA: 3:27 - sparse_categorical_accuracy: 0.9896 - student_loss: 0.1300 - distillation_loss: 0.0182
```

Train student from scratch for comparison

We can also train an equivalent student model from scratch, without the teacher, in order to evaluate the performance gain obtained by knowledge distillation.

```python
# Train student as usual
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
```

```
Epoch 1/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4731 - sparse_categorical_accuracy: 0.8550
Epoch 2/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0966 - sparse_categorical_accuracy: 0.9710
Epoch 3/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0750 - sparse_categorical_accuracy: 0.9773
313/313 [==============================] - 0s 963us/step - loss: 0.0691 - sparse_categorical_accuracy: 0.9778
[0.06905383616685867, 0.9778000116348267]
```

If the teacher was trained for 5 full epochs and the student was distilled on this teacher for 3 full epochs, in this example you should see a performance boost compared to training the same student model from scratch, and even compared to the teacher itself.
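As a final check, here is a hedged sketch (not part of the original tutorial) that pulls the distilled student out of the distiller via its student attribute and evaluates all three models side by side on the test set; it assumes the teacher, distiller, student_scratch, x_test, and y_test objects defined above.

```python
# Extract the distilled student; the Distiller stores it as `self.student`.
distilled_student = distiller.student

# The student inside the distiller was never compiled on its own,
# so compile it before calling evaluate().
distilled_student.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

for name, model in [
    ("teacher", teacher),
    ("distilled student", distilled_student),
    ("student from scratch", student_scratch),
]:
    loss, acc = model.evaluate(x_test, y_test, verbose=0)
    print(f"{name}: test accuracy = {acc:.4f}")
```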