title: Knowledge Distillation from Scratch
subtitle: Knowledge Distillation from Scratch
date: 2021-07-27
author: NSX
catalog: true
tags:
- Knowledge Distillation
Translated from https://keras.io/examples/vision/knowledge_distillation/#train-student-from-scratch-for-comparison
For more on distillation and model-inference acceleration, see the blog post 《预训练模型参数量越来越大?这里有你需要的BERT推理加速技术指南》.
Introduction to knowledge distillation
Knowledge distillation is a model-compression procedure in which a small (student) model is trained to match a large pre-trained (teacher) model. Knowledge is transferred from the teacher to the student by minimizing a loss function that aims to match both the softened teacher logits and the ground-truth labels.
The logits are softened by applying a "temperature" scaling function in the softmax, which effectively smooths the probability distribution and reveals the inter-class relationships learned by the teacher.
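To make the temperature idea concrete, here is a minimal sketch (not part of the original tutorial; the logits are made up for illustration) of how dividing logits by a temperature before the softmax flattens the distribution:

```python
import tensorflow as tf

# Hypothetical teacher logits for three classes (illustration only).
logits = tf.constant([[5.0, 2.0, 0.5]])

# Standard softmax (temperature = 1): a sharply peaked distribution.
print(tf.nn.softmax(logits / 1.0, axis=1).numpy())   # ~[[0.94, 0.05, 0.01]]

# Softened softmax (temperature = 10): much flatter, exposing relative class similarities.
print(tf.nn.softmax(logits / 10.0, axis=1).numpy())  # ~[[0.42, 0.31, 0.27]]
```

The higher the temperature, the closer the output gets to a uniform distribution, which is what lets the student learn from the teacher's "soft" class relationships rather than only the hard label.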
Setup
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
```
Constructing the Distiller() class
We define a custom Distiller() class that overrides the Model methods train_step, test_step, and compile(). In order to use the distiller, we need:
- A trained teacher model
- A student model to train
- A student loss function on the difference between student predictions and ground-truth
- A distillation loss function, along with a temperature, on the difference between the soft student predictions and the soft teacher labels
- An alpha factor to weight the student and distillation losses
- An optimizer for the student, and (optional) metrics to evaluate performance
In the train_step method, we perform a forward pass of both the teacher and the student, compute the loss as a weighting of student_loss and distillation_loss (by alpha and 1 - alpha, respectively), and perform the backward pass. Note: only the student weights are updated, and therefore we only calculate the gradients for the student weights.
In the test_step method, we evaluate the student model on the provided dataset.
```python
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
```
Create student and teacher models
Initially, we create a teacher model and a smaller student model. Both models are convolutional neural networks built with Sequential(), but they could be any Keras model.
```python
# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
```
Prepare the dataset
The dataset used for training the teacher and distilling the teacher is MNIST; the procedure would be equivalent for any other dataset, e.g. CIFAR-10, with a suitable choice of models. Both the student and the teacher are trained on the training set and evaluated on the test set.
```python
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
```
Train the teacher
In knowledge distillation, we assume that the teacher is already trained and fixed. Therefore, we start by training the teacher model on the training set in the usual way.
```python
# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
```

```
Epoch 1/5
1875/1875 [==============================] - 248s 132ms/step - loss: 0.2438 - sparse_categorical_accuracy: 0.9220
Epoch 2/5
1875/1875 [==============================] - 263s 140ms/step - loss: 0.0881 - sparse_categorical_accuracy: 0.9738
Epoch 3/5
1875/1875 [==============================] - 245s 131ms/step - loss: 0.0650 - sparse_categorical_accuracy: 0.9811
Epoch 5/5
 363/1875 [====>.........................] - ETA: 3:18 - loss: 0.0555 - sparse_categorical_accuracy: 0.9839
```
Distill teacher to student
With the teacher model trained, we only need to initialize a Distiller(student, teacher) instance, compile() it with the desired losses, hyperparameters, and optimizer, and distill the teacher to the student.
```python
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
```

```
Epoch 1/3
1875/1875 [==============================] - 242s 129ms/step - sparse_categorical_accuracy: 0.9761 - student_loss: 0.1526 - distillation_loss: 0.0226
Epoch 2/3
1875/1875 [==============================] - 281s 150ms/step - sparse_categorical_accuracy: 0.9863 - student_loss: 0.1384 - distillation_loss: 0.0185
Epoch 3/3
 399/1875 [=====>........................] - ETA: 3:27 - sparse_categorical_accuracy: 0.9896 - student_loss: 0.1300 - distillation_loss: 0.0182
```
Train student from scratch for comparison
We can also train an equivalent student model from scratch, without the teacher, in order to evaluate the performance gain obtained by knowledge distillation.
```python
# Train student as done usually
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
```

```
Epoch 1/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4731 - sparse_categorical_accuracy: 0.8550
Epoch 2/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0966 - sparse_categorical_accuracy: 0.9710
Epoch 3/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0750 - sparse_categorical_accuracy: 0.9773
313/313 [==============================] - 0s 963us/step - loss: 0.0691 - sparse_categorical_accuracy: 0.9778

[0.06905383616685867, 0.9778000116348267]
```
If the teacher is trained for 5 full epochs and the student is distilled on this teacher for 3 full epochs, then in this example you should see a performance boost compared to training the same student model from scratch, and even compared to the teacher itself.
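As a small follow-up sketch (not part of the original tutorial), the two students can be put side by side on the test set; this assumes `distiller` and `student_scratch` above have already finished training:

```python
# Hypothetical comparison snippet; assumes `distiller` and `student_scratch`
# from the previous steps have been trained on MNIST.
distilled = distiller.evaluate(x_test, y_test, return_dict=True, verbose=0)
scratch = student_scratch.evaluate(x_test, y_test, return_dict=True, verbose=0)

print("Distilled student accuracy:   ", distilled["sparse_categorical_accuracy"])
print("From-scratch student accuracy:", scratch["sparse_categorical_accuracy"])
```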
