# Pretraining BERT
:label:`sec_bert-pretraining`

With the BERT model implemented in :numref:`sec_bert` and the pretraining examples generated from the WikiText-2 dataset in :numref:`sec_bert-dataset`, we will pretrain BERT on the WikiText-2 dataset in this section.

```{.python .input}
from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx

npx.set_np()
```

```{.python .input}
#@tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
```

To start, we load the WikiText-2 dataset as minibatches of pretraining examples for masked language modeling and next sentence prediction. The batch size is 512 and the maximum length of a BERT input sequence is 64. Note that in the original BERT model, the maximum length is 512.

```{.python .input}
#@tab all
batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
```
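
If you want to verify the data pipeline before training, the following optional sanity check prints the shapes of the seven tensors in one minibatch (tokens, segments, valid lengths, prediction positions, masked language modeling weights, masked language modeling labels, and next sentence prediction labels, in the order used by the training loop later in this section). This snippet is only an illustration, not part of the original workflow.

```{.python .input}
#@tab all
# Optional sanity check: inspect the tensor shapes of a single minibatch.
# For example, the token and segment tensors should be (batch_size, max_len).
for (tokens_X, segments_X, valid_lens_x, pred_positions_X,
     mlm_weights_X, mlm_Y, nsp_y) in train_iter:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    break
```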

## Pretraining BERT

The original BERT has two versions of different model sizes :cite:`Devlin.Chang.Lee.ea.2018`. The base model ($\text{BERT}_{\text{BASE}}$) uses 12 layers (transformer encoder blocks) with 768 hidden units (hidden size) and 12 self-attention heads. The large model ($\text{BERT}_{\text{LARGE}}$) uses 24 layers with 1024 hidden units and 16 self-attention heads. Notably, the former has 110 million parameters while the latter has 340 million parameters. For easy demonstration, we define a small BERT, using 2 layers, 128 hidden units, and 2 self-attention heads.

```{.python .input}
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
                    num_heads=2, num_layers=2, dropout=0.2)
devices = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=devices)
loss = gluon.loss.SoftmaxCELoss()
```

```{.python .input}
#@tab pytorch
net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],
                    ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,
                    num_layers=2, dropout=0.2, key_size=128, query_size=128,
                    value_size=128, hid_in_features=128, mlm_in_features=128,
                    nsp_in_features=128)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()
```
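
As a rough, optional sanity check of model size (far below the 110 million and 340 million parameters quoted above, since this is a deliberately tiny configuration), the following sketch counts the trainable parameters of the small PyTorch BERT defined above.

```{.python .input}
#@tab pytorch
# Optional: count the trainable parameters of the small BERT defined above.
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{num_params} trainable parameters')
```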

Before defining the training loop, we define a helper function `_get_batch_loss_bert`. Given the shard of training examples, this function computes the loss for both the masked language modeling and next sentence prediction tasks. Note that the final loss of BERT pretraining is just the sum of the masked language modeling loss and the next sentence prediction loss; a small numeric illustration of the masked weighting follows the two implementations below.

```{.python .input}
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
                         segments_X_shards, valid_lens_x_shards,
                         pred_positions_X_shards, mlm_weights_X_shards,
                         mlm_Y_shards, nsp_y_shards):
    mlm_ls, nsp_ls, ls = [], [], []
    for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
         pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
         nsp_y_shard) in zip(
        tokens_X_shards, segments_X_shards, valid_lens_x_shards,
        pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
        nsp_y_shards):
        # Forward pass
        _, mlm_Y_hat, nsp_Y_hat = net(
            tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
            pred_positions_X_shard)
        # Compute masked language model loss
        mlm_l = loss(
            mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
            mlm_weights_X_shard.reshape((-1, 1)))
        mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
        # Compute next sentence prediction loss
        nsp_l = loss(nsp_Y_hat, nsp_y_shard)
        nsp_l = nsp_l.mean()
        mlm_ls.append(mlm_l)
        nsp_ls.append(nsp_l)
        ls.append(mlm_l + nsp_l)
        npx.waitall()
    return mlm_ls, nsp_ls, ls
```

```{.python .input}
#@tab pytorch
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Compute masked language model loss
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l
```
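
To see why the per-position loss is multiplied by `mlm_weights_X` and then normalized by the weight sum, here is a small toy computation with made-up numbers: padded prediction slots (weight 0) contribute nothing to the averaged masked language modeling loss.

```{.python .input}
#@tab pytorch
# Toy illustration with made-up values: three prediction slots, the last one
# is padding. Weighting by 0 removes the padded slot from the average.
per_token_loss = torch.tensor([2.0, 1.0, 5.0])  # unreduced cross-entropy
weights = torch.tensor([1.0, 1.0, 0.0])         # 0 marks a padded slot
masked_mean = (per_token_loss * weights).sum() / (weights.sum() + 1e-8)
print(masked_mean)  # (2.0 + 1.0) / 2 = 1.5; the padded slot is ignored
```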
Invoking the `_get_batch_loss_bert` helper function defined above, the following `train_bert` function defines the procedure to pretrain BERT (`net`) on the WikiText-2 (`train_iter`) dataset. Training BERT can take a very long time. Instead of specifying the number of epochs for training as in the `train_ch13` function (see :numref:`sec_image_augmentation`), the input `num_steps` of the following function specifies the number of iteration steps for training.

```{.python .input}
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': 1e-3})
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for batch in train_iter:
            (tokens_X_shards, segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards,
             mlm_Y_shards, nsp_y_shards) = [gluon.utils.split_and_load(
                elem, devices, even_split=False) for elem in batch]
            timer.start()
            with autograd.record():
                mlm_ls, nsp_ls, ls = _get_batch_loss_bert(
                    net, loss, vocab_size, tokens_X_shards, segments_X_shards,
                    valid_lens_x_shards, pred_positions_X_shards,
                    mlm_weights_X_shards, mlm_Y_shards, nsp_y_shards)
            for l in ls:
                l.backward()
            trainer.step(1)
            mlm_l_mean = sum([float(l) for l in mlm_ls]) / len(mlm_ls)
            nsp_l_mean = sum([float(l) for l in nsp_ls]) / len(nsp_ls)
            metric.add(mlm_l_mean, nsp_l_mean, batch[0].shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break
    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')
```

```{.python .input}
#@tab pytorch
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=1e-3)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break
    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')
```
We can plot both the masked language modeling loss and the next sentence prediction loss during BERT pretraining.

```{.python .input}
#@tab all
train_bert(train_iter, net, loss, len(vocab), devices, 50)
```
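
If you want to reuse the pretrained weights later (for example, for the fine-tuning applications in :numref:`chap_nlp_app`), an optional sketch in the PyTorch tab is to save the model's `state_dict`; the file name below is only an illustration.

```{.python .input}
#@tab pytorch
# Optional: persist the pretrained parameters (illustrative file name).
torch.save(net.state_dict(), 'bert.small.pretrained.params')
# A model constructed with the same hyperparameters can reload them later:
# net.load_state_dict(torch.load('bert.small.pretrained.params'))
```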

## Representing Text with BERT

After pretraining BERT, we can use it to represent a single text, text pairs, or any token in them. The following function returns the BERT (`net`) representations for all tokens in `tokens_a` and `tokens_b`.

```{.python .input}
def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = np.expand_dims(np.array(vocab[tokens], ctx=devices[0]),
                               axis=0)
    segments = np.expand_dims(np.array(segments, ctx=devices[0]), axis=0)
    valid_len = np.expand_dims(np.array(len(tokens), ctx=devices[0]), axis=0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X
```

```{.python .input}
#@tab pytorch
def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)
    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)
    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X
```

Consider the sentence "a crane is flying". Recall the input representation of BERT as discussed in :numref:`subsec_bert_input_rep`. After inserting the special tokens "<cls>" (used for classification) and "<sep>" (used for separation), the BERT input sequence has a length of six. Since zero is the index of the "<cls>" token, `encoded_text[:, 0, :]` is the BERT representation of the entire input sentence. To evaluate the polysemous token "crane", we also print out the first three elements of its BERT representation.

```{.python .input}
#@tab all
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# Tokens: '<cls>', 'a', 'crane', 'is', 'flying', '<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
```

Now consider a sentence pair "a crane driver came" and "he just left". Similarly, `encoded_pair[:, 0, :]` is the encoded result of the entire sentence pair from the pretrained BERT. Note that the first three elements of the representation of the polysemous token "crane" differ from those obtained when its context was different. This supports that BERT representations are context-sensitive, as the optional check after the following code illustrates.

```{.python .input}
#@tab all
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# Tokens: '<cls>', 'a', 'crane', 'driver', 'came', '<sep>', 'he', 'just',
# 'left', '<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
```
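
To go beyond eyeballing three numbers, the following optional sketch computes the cosine similarity between the two context-dependent "crane" representations; with such a briefly pretrained small model the exact value is not meaningful, but it shows one way to quantify context sensitivity.

```{.python .input}
#@tab pytorch
# Optional: cosine similarity between the two "crane" representations
# (a value of 1.0 would mean the vectors are identical).
print(torch.cosine_similarity(encoded_text_crane, encoded_pair_crane))
```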

In :numref:`chap_nlp_app`, we will fine-tune a pretrained BERT model for downstream natural language processing applications.

## Summary

* The original BERT has two versions, where the base model has 110 million parameters and the large model has 340 million parameters.
* After pretraining BERT, we can use it to represent a single text, text pairs, or any token in them.
* In the experiment, the same token has different BERT representations when its contexts are different. This supports that BERT representations are context-sensitive.

## Exercises

1. In the experiment, we can see that the masked language modeling loss is significantly higher than the next sentence prediction loss. Why?
2. Set the maximum length of a BERT input sequence to be 512 (same as the original BERT model). Use the configurations of the original BERT model such as $\text{BERT}_{\text{LARGE}}$. Do you encounter any error when running this section? Why?

:begin_tab:`mxnet`
Discussions
:end_tab:

:begin_tab:`pytorch`
Discussions
:end_tab: