预训练BERT

:label:sec_bert-pretraining

利用 :numref:sec_bert中实现的BERT模型和 :numref:sec_bert-dataset中从WikiText-2数据集生成的预训练样本,我们将在本节中在WikiText-2数据集上对BERT进行预训练。

```{.python .input} from d2l import mxnet as d2l from mxnet import autograd, gluon, init, np, npx

npx.set_np()

  1. ```{.python .input}
  2. #@tab pytorch
  3. from d2l import torch as d2l
  4. import torch
  5. from torch import nn

首先,我们加载WikiText-2数据集作为小批量的预训练样本,用于遮蔽语言模型和下一句预测。批量大小是512,BERT输入序列的最大长度是64。注意,在原始BERT模型中,最大长度是512。

```{.python .input}

@tab all

batch_size, max_len = 512, 64 train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)

  1. ## 预训练BERT
  2. 原始BERT :cite:`Devlin.Chang.Lee.ea.2018`有两个不同模型尺寸的版本。基本模型($\text{BERT}_{\text{BASE}}$)使用12层(Transformer编码器块),768个隐藏单元(隐藏大小)和12个自注意头。大模型($\text{BERT}_{\text{LARGE}}$)使用24层,1024个隐藏单元和16个自注意头。值得注意的是,前者有1.1亿个参数,后者有3.4亿个参数。为了便于演示,我们定义了一个小的BERT,使用了2层、128个隐藏单元和2个自注意头。
  3. ```{.python .input}
  4. net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
  5. num_heads=2, num_layers=2, dropout=0.2)
  6. devices = d2l.try_all_gpus()
  7. net.initialize(init.Xavier(), ctx=devices)
  8. loss = gluon.loss.SoftmaxCELoss()

```{.python .input}

@tab pytorch

net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128], ffn_num_input=128, ffn_num_hiddens=256, num_heads=2, num_layers=2, dropout=0.2, key_size=128, query_size=128, value_size=128, hid_in_features=128, mlm_in_features=128, nsp_in_features=128) devices = d2l.try_all_gpus() loss = nn.CrossEntropyLoss()

  1. 在定义训练代码实现之前,我们定义了一个辅助函数`_get_batch_loss_bert`。给定训练样本,该函数计算遮蔽语言模型和下一句子预测任务的损失。请注意,BERT预训练的最终损失是遮蔽语言模型损失和下一句预测损失的和。
  2. ```{.python .input}
  3. #@save
  4. def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
  5. segments_X_shards, valid_lens_x_shards,
  6. pred_positions_X_shards, mlm_weights_X_shards,
  7. mlm_Y_shards, nsp_y_shards):
  8. mlm_ls, nsp_ls, ls = [], [], []
  9. for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
  10. pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
  11. nsp_y_shard) in zip(
  12. tokens_X_shards, segments_X_shards, valid_lens_x_shards,
  13. pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
  14. nsp_y_shards):
  15. # 前向传播
  16. _, mlm_Y_hat, nsp_Y_hat = net(
  17. tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
  18. pred_positions_X_shard)
  19. # 计算遮蔽语言模型损失
  20. mlm_l = loss(
  21. mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
  22. mlm_weights_X_shard.reshape((-1, 1)))
  23. mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
  24. # 计算下一句子预测任务的损失
  25. nsp_l = loss(nsp_Y_hat, nsp_y_shard)
  26. nsp_l = nsp_l.mean()
  27. mlm_ls.append(mlm_l)
  28. nsp_ls.append(nsp_l)
  29. ls.append(mlm_l + nsp_l)
  30. npx.waitall()
  31. return mlm_ls, nsp_ls, ls

```{.python .input}

@tab pytorch

@save

def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):

  1. # 前向传播
  2. _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
  3. valid_lens_x.reshape(-1),
  4. pred_positions_X)
  5. # 计算遮蔽语言模型损失
  6. mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
  7. mlm_weights_X.reshape(-1, 1)
  8. mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
  9. # 计算下一句子预测任务的损失
  10. nsp_l = loss(nsp_Y_hat, nsp_y)
  11. l = mlm_l + nsp_l
  12. return mlm_l, nsp_l, l
  1. 通过调用上述两个辅助函数,下面的`train_bert`函数定义了在WikiText-2`train_iter`)数据集上预训练BERT`net`)的过程。训练BERT可能需要很长时间。以下函数的输入`num_steps`指定了训练的迭代步数,而不是像`train_ch13`函数那样指定训练的轮数(参见 :numref:`sec_image_augmentation`)。
  2. ```{.python .input}
  3. def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
  4. trainer = gluon.Trainer(net.collect_params(), 'adam',
  5. {'learning_rate': 0.01})
  6. step, timer = 0, d2l.Timer()
  7. animator = d2l.Animator(xlabel='step', ylabel='loss',
  8. xlim=[1, num_steps], legend=['mlm', 'nsp'])
  9. # 遮蔽语言模型损失的和,下一句预测任务损失的和,句子对的数量,计数
  10. metric = d2l.Accumulator(4)
  11. num_steps_reached = False
  12. while step < num_steps and not num_steps_reached:
  13. for batch in train_iter:
  14. (tokens_X_shards, segments_X_shards, valid_lens_x_shards,
  15. pred_positions_X_shards, mlm_weights_X_shards,
  16. mlm_Y_shards, nsp_y_shards) = [gluon.utils.split_and_load(
  17. elem, devices, even_split=False) for elem in batch]
  18. timer.start()
  19. with autograd.record():
  20. mlm_ls, nsp_ls, ls = _get_batch_loss_bert(
  21. net, loss, vocab_size, tokens_X_shards, segments_X_shards,
  22. valid_lens_x_shards, pred_positions_X_shards,
  23. mlm_weights_X_shards, mlm_Y_shards, nsp_y_shards)
  24. for l in ls:
  25. l.backward()
  26. trainer.step(1)
  27. mlm_l_mean = sum([float(l) for l in mlm_ls]) / len(mlm_ls)
  28. nsp_l_mean = sum([float(l) for l in nsp_ls]) / len(nsp_ls)
  29. metric.add(mlm_l_mean, nsp_l_mean, batch[0].shape[0], 1)
  30. timer.stop()
  31. animator.add(step + 1,
  32. (metric[0] / metric[3], metric[1] / metric[3]))
  33. step += 1
  34. if step == num_steps:
  35. num_steps_reached = True
  36. break
  37. print(f'MLM loss {metric[0] / metric[3]:.3f}, '
  38. f'NSP loss {metric[1] / metric[3]:.3f}')
  39. print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
  40. f'{str(devices)}')

```{.python .input}

@tab pytorch

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps): net = nn.DataParallel(net, device_ids=devices).to(devices[0]) trainer = torch.optim.Adam(net.parameters(), lr=0.01) step, timer = 0, d2l.Timer() animator = d2l.Animator(xlabel=’step’, ylabel=’loss’, xlim=[1, num_steps], legend=[‘mlm’, ‘nsp’])

  1. # 遮蔽语言模型损失的和,下一句预测任务损失的和,句子对的数量,计数
  2. metric = d2l.Accumulator(4)
  3. num_steps_reached = False
  4. while step < num_steps and not num_steps_reached:
  5. for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
  6. mlm_weights_X, mlm_Y, nsp_y in train_iter:
  7. tokens_X = tokens_X.to(devices[0])
  8. segments_X = segments_X.to(devices[0])
  9. valid_lens_x = valid_lens_x.to(devices[0])
  10. pred_positions_X = pred_positions_X.to(devices[0])
  11. mlm_weights_X = mlm_weights_X.to(devices[0])
  12. mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
  13. trainer.zero_grad()
  14. timer.start()
  15. mlm_l, nsp_l, l = _get_batch_loss_bert(
  16. net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
  17. pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
  18. l.backward()
  19. trainer.step()
  20. metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
  21. timer.stop()
  22. animator.add(step + 1,
  23. (metric[0] / metric[3], metric[1] / metric[3]))
  24. step += 1
  25. if step == num_steps:
  26. num_steps_reached = True
  27. break
  28. print(f'MLM loss {metric[0] / metric[3]:.3f}, '
  29. f'NSP loss {metric[1] / metric[3]:.3f}')
  30. print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
  31. f'{str(devices)}')
  1. 在预训练过程中,我们可以绘制出遮蔽语言模型损失和下一句预测损失。
  2. ```{.python .input}
  3. #@tab all
  4. train_bert(train_iter, net, loss, len(vocab), devices, 50)

用BERT表示文本

在预训练BERT之后,我们可以用它来表示单个文本、文本对或其中的任何词元。下面的函数返回tokens_atokens_b中所有词元的BERT(net)表示。

```{.python .input} def getbert_encoding(net, tokens_a, tokens_b=None): tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b) token_ids = np.expand_dims(np.array(vocab[tokens], ctx=devices[0]), axis=0) segments = np.expand_dims(np.array(segments, ctx=devices[0]), axis=0) valid_len = np.expand_dims(np.array(len(tokens), ctx=devices[0]), axis=0) encoded_X, , _ = net(token_ids, segments, valid_len) return encoded_X

  1. ```{.python .input}
  2. #@tab pytorch
  3. def get_bert_encoding(net, tokens_a, tokens_b=None):
  4. tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
  5. token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)
  6. segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)
  7. valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
  8. encoded_X, _, _ = net(token_ids, segments, valid_len)
  9. return encoded_X

考虑“a crane is flying”这句话。回想一下 :numref:subsec_bert_input_rep中讨论的BERT的输入表示。插入特殊标记“<cls>”(用于分类)和“<sep>”(用于分隔)后,BERT输入序列的长度为6。因为零是“<cls>”词元,encoded_text[:, 0, :]是整个输入语句的BERT表示。为了评估一词多义词元“crane”,我们还打印出了该词元的BERT表示的前三个元素。

```{.python .input}

@tab all

tokens_a = [‘a’, ‘crane’, ‘is’, ‘flying’] encoded_text = get_bert_encoding(net, tokens_a)

词元:’‘,’a’,’crane’,’is’,’flying’,’

encoded_text_cls = encoded_text[:, 0, :] encoded_text_crane = encoded_text[:, 2, :] encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]

  1. 现在考虑一个句子“a crane driver came”和“he just left”。类似地,`encoded_pair[:, 0, :]`是来自预训练BERT的整个句子对的编码结果。注意,多义词元“crane”的前三个元素与上下文不同时的元素不同。这支持了BERT表示是上下文敏感的。
  2. ```{.python .input}
  3. #@tab all
  4. tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
  5. encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
  6. # 词元:'<cls>','a','crane','driver','came','<sep>','he','just',
  7. # 'left','<sep>'
  8. encoded_pair_cls = encoded_pair[:, 0, :]
  9. encoded_pair_crane = encoded_pair[:, 2, :]
  10. encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]

在 :numref:chap_nlp_app中,我们将为下游自然语言处理应用微调预训练的BERT模型。

小结

  • 原始的BERT有两个版本,其中基本模型有1.1亿个参数,大模型有3.4亿个参数。
  • 在预训练BERT之后,我们可以用它来表示单个文本、文本对或其中的任何词元。
  • 在实验中,同一个词元在不同的上下文中具有不同的BERT表示。这支持BERT表示是上下文敏感的。

练习

  1. 在实验中,我们可以看到遮蔽语言模型损失明显高于下一句预测损失。为什么?
  2. 将BERT输入序列的最大长度设置为512(与原始BERT模型相同)。使用原始BERT模型的配置,如$\text{BERT}_{\text{LARGE}}$。运行此部分时是否遇到错误?为什么?

:begin_tab:mxnet Discussions :end_tab:

:begin_tab:pytorch Discussions :end_tab: