# Natural Language Inference: Fine-Tuning BERT
:label:`sec_natural-language-inference-bert`

In earlier sections of this chapter, we have designed an attention-based architecture (in :numref:`sec_natural-language-inference-attention`) for the natural language inference task on the SNLI dataset (as described in :numref:`sec_natural-language-inference-and-dataset`). Now we revisit this task by fine-tuning BERT. As discussed in :numref:`sec_finetuning-bert`, natural language inference is a sequence-level text pair classification problem, and fine-tuning BERT only requires an additional MLP-based architecture, as illustrated in :numref:`fig_nlp-map-nli-bert`.

![This section feeds pretrained BERT to an MLP-based architecture for natural language inference.](../img/nlp-map-nli-bert.svg)
:label:`fig_nlp-map-nli-bert`

In this section, we will download a pretrained small version of BERT, then fine-tune it for natural language inference on the SNLI dataset.

```{.python .input}
from d2l import mxnet as d2l
import json
import multiprocessing
from mxnet import gluon, np, npx
from mxnet.gluon import nn
import os

npx.set_np()
```

```{.python .input}
#@tab pytorch
from d2l import torch as d2l
import json
import multiprocessing
import torch
from torch import nn
import os
```

## Loading Pretrained BERT

We have explained how to pretrain BERT on the WikiText-2 dataset in :numref:`sec_bert-dataset` and :numref:`sec_bert-pretraining` (note that the original BERT model is pretrained on much bigger corpora). As discussed in :numref:`sec_bert-pretraining`, the original BERT model has hundreds of millions of parameters. In the following, we provide two versions of pretrained BERT: “bert.base” is about as big as the original BERT base model, which requires a lot of computational resources to fine-tune, while “bert.small” is a small version to facilitate demonstration.

```{.python .input}
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
                             '7b3820b35da691042e5d34c0971ac3edbd80d3f4')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
                              'a4e718a47137ccd1809c9107ab4f5edd317bae2c')
```

```{.python .input}
#@tab pytorch
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
                             '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
                              'c72329e68a732bef0452e4b96a1c341c8910f81f')
```

Either pretrained BERT model contains a “vocab.json” file that defines the vocabulary set and a “pretrained.params” file of the pretrained parameters. We implement the following `load_pretrained_model` function to load pretrained BERT parameters.

```{.python .input}
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                          num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.download_extract(pretrained_model)
    # Define an empty vocabulary to load the predefined vocabulary
    vocab = d2l.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(
        vocab.idx_to_token)}
    bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
                         num_layers, dropout, max_len)
    # Load pretrained BERT parameters
    bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
                         ctx=devices)
    return bert, vocab
```
```{.python .input}
#@tab pytorch
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                          num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.download_extract(pretrained_model)
    # Define an empty vocabulary to load the predefined vocabulary
    vocab = d2l.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(
        vocab.idx_to_token)}
    bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
                         ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
                         num_heads=4, num_layers=2, dropout=0.2,
                         max_len=max_len, key_size=256, query_size=256,
                         value_size=256, hid_in_features=256,
                         mlm_in_features=256, nsp_in_features=256)
    # Load pretrained BERT parameters
    bert.load_state_dict(torch.load(os.path.join(data_dir,
                                                 'pretrained.params')))
    return bert, vocab
```

To facilitate demonstration on most machines, we will load and fine-tune the small version (“bert.small”) of pretrained BERT in this section. In the exercises, we will show how to fine-tune the much larger “bert.base” to significantly improve the testing accuracy.

```{.python .input}
#@tab all
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
    'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
    num_layers=2, dropout=0.1, max_len=512, devices=devices)
```

## The Dataset for Fine-Tuning BERT

For the downstream task of natural language inference on the SNLI dataset,
we define a customized dataset class `SNLIBERTDataset`.
In each example, the premise and hypothesis form a pair of text sequences
and are packed into one BERT input sequence as depicted in :numref:`fig_bert-two-seqs`.
Recall from :numref:`subsec_bert_input_rep` that segment IDs
are used to distinguish the premise from the hypothesis in a BERT input sequence.
With the predefined maximum length of a BERT input sequence (`max_len`),
the last token of the longer of the input text pair keeps getting removed until
`max_len` is met.
To accelerate generation of the SNLI dataset for fine-tuning BERT,
we use 4 worker processes to generate training or testing examples in parallel.
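
Before implementing the dataset class, the following toy example (the two token lists are made up for illustration) shows how `d2l.get_tokens_and_segments` packs a premise and a hypothesis into one BERT input sequence with segment IDs; the `_mp_worker` method below relies on exactly this packing.

```{.python .input}
#@tab all
# Toy illustration with made-up tokens: pack a premise--hypothesis pair into
# one BERT input sequence; segment IDs are 0 for the premise part and 1 for
# the hypothesis part
p_tokens = ['a', 'person', 'is', 'outside', '.']
h_tokens = ['a', 'person', 'is', 'outdoors', '.']
tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
print(tokens)
print(segments)
```
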
```{.python .input}
class SNLIBERTDataset(gluon.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]
        self.labels = np.array(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        pool = multiprocessing.Pool(4)  # Use 4 worker processes
        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (np.array(all_token_ids, dtype='int32'),
                np.array(all_segments, dtype='int32'),
                np.array(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
            * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the BERT
        # input
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)
```

```{.python .input}
#@tab pytorch
class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]
        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        pool = multiprocessing.Pool(4)  # Use 4 worker processes
        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
            * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the BERT
        # input
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)
```
After downloading the SNLI dataset,
we generate training and testing examples
by instantiating the `SNLIBERTDataset` class.
Such examples will be read in minibatches during training and testing
of natural language inference.

```{.python .input}
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=num_workers)
test_iter = gluon.data.DataLoader(test_set, batch_size,
                                  num_workers=num_workers)
```

```{.python .input}
#@tab pytorch
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                         num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                        num_workers=num_workers)
```
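
As an optional check (a sketch, not part of the original pipeline), we can read one minibatch: token IDs and segment IDs have shape (`batch_size`, `max_len`), while valid lengths and labels have shape (`batch_size`,).

```{.python .input}
#@tab all
# Optional check: shapes of one minibatch of packed BERT inputs and labels
for (tokens_X, segments_X, valid_lens_x), y in train_iter:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape, y.shape)
    break
```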

## Fine-Tuning BERT

As :numref:`fig_bert-two-seqs` indicates,
fine-tuning BERT for natural language inference
requires only an extra MLP consisting of two fully connected layers
(see `self.hidden` and `self.output` in the following `BERTClassifier` class).
This MLP transforms the
BERT representation of the special “&lt;cls&gt;” token,
which encodes the information of both the premise and the hypothesis,
into three outputs of natural language inference:
entailment, contradiction, and neutral.

```{.python .input}
class BERTClassifier(nn.Block):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Dense(3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        return self.output(self.hidden(encoded_X[:, 0, :]))
```

```{.python .input}
#@tab pytorch
class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        return self.output(self.hidden(encoded_X[:, 0, :]))
```
In the following,
the pretrained BERT model `bert` is fed into the `BERTClassifier` instance `net` for
the downstream application.
In common implementations of BERT fine-tuning,
only the parameters of the output layer of the additional MLP (`net.output`) will be learned from scratch.
All the parameters of the pretrained BERT encoder (`net.encoder`) and the hidden layer of the additional MLP (`net.hidden`) will be fine-tuned.

```{.python .input}
net = BERTClassifier(bert)
net.output.initialize(ctx=devices)
```

```{.python .input}
#@tab pytorch
net = BERTClassifier(bert)
```
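
As an optional check in the PyTorch tab (a sketch, not part of the original training pipeline), we can pass a couple of examples through the untrained classifier on the CPU: the output has one logit per class (entailment, contradiction, and neutral).

```{.python .input}
#@tab pytorch
# Optional check (a sketch): two examples through the untrained classifier
# yield one logit per NLI class, i.e., a shape of (2, 3)
(tokens_X, segments_X, valid_lens_x), y = next(iter(train_iter))
with torch.no_grad():
    print(net((tokens_X[:2], segments_X[:2], valid_lens_x[:2])).shape)
```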

Recall that
in :numref:`sec_bert`
both the `MaskLM` class and the `NextSentencePred` class
have parameters in their employed MLPs.
These parameters are part of those in the pretrained BERT model
`bert`, and thus part of the parameters in `net`.
However, such parameters are only for computing
the masked language modeling loss
and the next sentence prediction loss
during pretraining.
These two loss functions are irrelevant to fine-tuning downstream applications,
thus the parameters of the employed MLPs in
`MaskLM` and `NextSentencePred` are not updated (they become stale) when BERT is fine-tuned.
To allow parameters with stale gradients,
the flag `ignore_stale_grad=True` is set in the `step` function of `d2l.train_batch_ch13`.
We use this function to train and evaluate the model `net` using the training set
(`train_iter`) and the testing set (`test_iter`) of SNLI.
Due to the limited computational resources, the training and testing accuracy
can be further improved: we leave its discussion to the exercises.
```{.python .input}
lr, num_epochs = 1e-4, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
               d2l.split_batch_multi_inputs)
```

```{.python .input}
#@tab pytorch
lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
```
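
Since only `net.output` is trained from scratch while `net.encoder` and `net.hidden` start from pretrained parameters, a common variation (not used in the experiment above) is to give the freshly initialized output layer a larger learning rate than the pretrained parts via optimizer parameter groups. Below is a minimal PyTorch sketch under the hyperparameters defined earlier; the name `trainer_variant` and the factor of 10 are illustrative choices, not part of the original recipe.

```{.python .input}
#@tab pytorch
# A sketch of a common variation (not used above): a larger learning rate for
# the freshly initialized output layer than for the pretrained parameters
trainer_variant = torch.optim.Adam([
    {'params': [*net.encoder.parameters(), *net.hidden.parameters()],
     'lr': lr},
    {'params': net.output.parameters(), 'lr': lr * 10}])
```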

## Summary

* We can fine-tune the pretrained BERT model for downstream applications, such as natural language inference on the SNLI dataset.
* During fine-tuning, the BERT model becomes part of the model for the downstream application. Parameters that are only related to pretraining losses will not be updated during fine-tuning.

## Exercises

1. Fine-tune a much larger pretrained BERT model that is about as big as the original BERT base model if your computational resources allow. Set arguments in the `load_pretrained_model` function as follows: replace `'bert.small'` with `'bert.base'`, and increase the values of `num_hiddens=256`, `ffn_num_hiddens=512`, `num_heads=4`, and `num_layers=2` to 768, 3072, 12, and 12, respectively. By increasing the number of fine-tuning epochs (and possibly tuning other hyperparameters), can you get a testing accuracy higher than 0.86?
2. How can we truncate a pair of sequences according to their length ratio? Compare this pair truncation method with the one used in the `SNLIBERTDataset` class. What are their pros and cons?

:begin_tab:`mxnet`
Discussions
:end_tab:

:begin_tab:`pytorch`
Discussions
:end_tab: