• Source Code Analysis

    Code directory structure (screenshot omitted)
    Entity recognition model (`entity/models.py`); the Albert-related parts are omitted:

    ```python
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torch.nn import CrossEntropyLoss
    from allennlp.nn.util import batched_index_select
    from allennlp.modules import FeedForward
    from transformers import BertTokenizer, BertPreTrainedModel, BertModel
    from transformers import AlbertTokenizer, AlbertPreTrainedModel, AlbertModel

    import logging
    logger = logging.getLogger('root')


    class BertForEntity(BertPreTrainedModel):
        def __init__(self, config, num_ner_labels, head_hidden_dim=150, width_embedding_dim=150, max_span_length=8):
            super().__init__(config)

            self.bert = BertModel(config)
            self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
            # Learned embedding for the span-width feature (indices 0..max_span_length)
            self.width_embedding = nn.Embedding(max_span_length + 1, width_embedding_dim)

            # Two-layer feed-forward head, then a linear projection to the label space
            self.ner_classifier = nn.Sequential(
                FeedForward(input_dim=config.hidden_size * 2 + width_embedding_dim,
                            num_layers=2,
                            hidden_dims=head_hidden_dim,
                            activations=F.relu,
                            dropout=0.2),
                nn.Linear(head_hidden_dim, num_ner_labels)
            )

            self.init_weights()

        def _get_span_embeddings(self, input_ids, spans, token_type_ids=None, attention_mask=None):
            sequence_output, pooled_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
            sequence_output = self.hidden_dropout(sequence_output)

            """
            spans: [batch_size, num_spans, 3]; 0: left_end, 1: right_end, 2: width
            spans_mask: (batch_size, num_spans, )
            """
            spans_start = spans[:, :, 0].view(spans.size(0), -1)
            spans_start_embedding = batched_index_select(sequence_output, spans_start)
            spans_end = spans[:, :, 1].view(spans.size(0), -1)
            spans_end_embedding = batched_index_select(sequence_output, spans_end)

            spans_width = spans[:, :, 2].view(spans.size(0), -1)
            spans_width_embedding = self.width_embedding(spans_width)

            # Concatenate embeddings of left/right points and the width embedding
            spans_embedding = torch.cat((spans_start_embedding, spans_end_embedding, spans_width_embedding), dim=-1)
            """
            spans_embedding: (batch_size, num_spans, hidden_size*2+embedding_dim)
            """
            return spans_embedding

        def forward(self, input_ids, spans, spans_mask, spans_ner_label=None, token_type_ids=None, attention_mask=None):
            spans_embedding = self._get_span_embeddings(input_ids, spans, token_type_ids=token_type_ids, attention_mask=attention_mask)
            ffnn_hidden = []
            hidden = spans_embedding
            for layer in self.ner_classifier:
                hidden = layer(hidden)
                ffnn_hidden.append(hidden)
            logits = ffnn_hidden[-1]

            if spans_ner_label is not None:
                loss_fct = CrossEntropyLoss(reduction='sum')
                if attention_mask is not None:
                    # Replace the labels of padded-out spans with ignore_index
                    # so that they do not contribute to the loss
                    active_loss = spans_mask.view(-1) == 1
                    active_logits = logits.view(-1, logits.shape[-1])
                    active_labels = torch.where(
                        active_loss, spans_ner_label.view(-1), torch.tensor(loss_fct.ignore_index).type_as(spans_ner_label)
                    )
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, logits.shape[-1]), spans_ner_label.view(-1))
                return loss, logits, spans_embedding
            else:
                # At inference time spans_embedding is returned twice; run_batch
                # reads the third element as "last_hidden"
                return logits, spans_embedding, spans_embedding
    ```
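
    The classifier scores every candidate span up to `max_span_length` tokens. As a minimal sketch of the input convention (the `enumerate_spans` helper below is hypothetical, not part of the repository), the `[start, end, width]` triples expected by `forward` can be produced like this:

    ```python
    import torch

    def enumerate_spans(seq_len, max_span_length=8):
        """Enumerate all [start, end, width] spans with width <= max_span_length."""
        spans = []
        for start in range(seq_len):
            for width in range(1, max_span_length + 1):
                end = start + width - 1
                if end < seq_len:
                    spans.append([start, end, width])
        return spans

    # One sentence of 5 tokens, spans capped at width 3
    spans = torch.tensor([enumerate_spans(5, max_span_length=3)])
    print(spans.shape)  # torch.Size([1, 12, 3])
    ```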

    For reference, the `batched_index_select` code from `allennlp.nn.util`:

    ````python
    from typing import Optional

    import torch

    # ConfigurationError, get_range_vector and get_device_of are allennlp helpers
    # (allennlp.common.checks and allennlp.nn.util respectively).

    def batched_index_select(
        target: torch.Tensor,
        indices: torch.LongTensor,
        flattened_indices: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        The given `indices` of size `(batch_size, d_1, ..., d_n)` indexes into the sequence
        dimension (dimension 2) of the target, which has size `(batch_size, sequence_length,
        embedding_size)`.

        This function returns selected values in the target with respect to the provided indices, which
        have size `(batch_size, d_1, ..., d_n, embedding_size)`. This can use the optionally
        precomputed `flattened_indices` with size `(batch_size * d_1 * ... * d_n)` if given.

        An example use case of this function is looking up the start and end indices of spans in a
        sequence tensor. This is used in the
        [CoreferenceResolver](https://docs.allennlp.org/models/main/models/coref/models/coref/)
        model to select contextual word representations corresponding to the start and end indices of
        mentions.

        The key reason this can't be done with basic torch functions is that we want to be able to use
        look-up tensors with an arbitrary number of dimensions (for example, in the coref model, we
        don't know a-priori how many spans we are looking up).

        # Parameters

        target : `torch.Tensor`, required.
            A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
            This is the tensor to be indexed.
        indices : `torch.LongTensor`
            A tensor of shape (batch_size, ...), where each element is an index into the
            `sequence_length` dimension of the `target` tensor.
        flattened_indices : `Optional[torch.Tensor]`, optional (default = `None`)
            An optional tensor representing the result of calling `flatten_and_batch_shift_indices`
            on `indices`. This is helpful in the case that the indices can be flattened once and
            cached for many batch lookups.

        # Returns

        selected_targets : `torch.Tensor`
            A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
            extracted from the batch flattened target tensor.
        """
        if flattened_indices is None:
            # Shape: (batch_size * d_1 * ... * d_n)
            flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))

        # Shape: (batch_size * sequence_length, embedding_size)
        flattened_target = target.view(-1, target.size(-1))

        # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
        flattened_selected = flattened_target.index_select(0, flattened_indices)
        selected_shape = list(indices.size()) + [target.size(-1)]
        # Shape: (batch_size, d_1, ..., d_n, embedding_size)
        selected_targets = flattened_selected.view(*selected_shape)
        return selected_targets


    def flatten_and_batch_shift_indices(indices: torch.Tensor, sequence_length: int) -> torch.Tensor:
        """
        This is a subroutine for [`batched_index_select`](./util.md#batched_index_select).
        The given `indices` of size `(batch_size, d_1, ..., d_n)` indexes into dimension 2 of a
        target tensor, which has size `(batch_size, sequence_length, embedding_size)`. This
        function returns a vector that correctly indexes into the flattened target. The sequence
        length of the target must be provided to compute the appropriate offsets.

        ```python
        indices = torch.ones([2, 3], dtype=torch.long)
        # Sequence length of the target tensor.
        sequence_length = 10
        shifted_indices = flatten_and_batch_shift_indices(indices, sequence_length)
        # Indices into the second element in the batch are correctly shifted
        # to take into account that the target tensor will be flattened before
        # the indices are applied.
        assert shifted_indices == [1, 1, 1, 11, 11, 11]
        ```

        # Parameters

        indices : `torch.LongTensor`, required.
        sequence_length : `int`, required.
            The length of the sequence the indices index into.
            This must be the second dimension of the tensor.

        # Returns

        offset_indices : `torch.LongTensor`
        """
        # Shape: (batch_size)
        if torch.max(indices) >= sequence_length or torch.min(indices) < 0:
            raise ConfigurationError(
                f"All elements in indices should be in range (0, {sequence_length - 1})"
            )
        offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length
        for _ in range(len(indices.size()) - 1):
            offsets = offsets.unsqueeze(1)

        # Shape: (batch_size, d_1, ..., d_n)
        offset_indices = indices + offsets

        # Shape: (batch_size * d_1 * ... * d_n)
        offset_indices = offset_indices.view(-1)
        return offset_indices
    ````
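
    For intuition about what `batched_index_select` computes: with a 2-D `indices` tensor it is equivalent to a `torch.gather` along the sequence dimension. A small self-contained check (toy tensors, not from the repository):

    ```python
    import torch

    # target: (batch=2, seq_len=5, emb=4); select positions [1, 3] and [0, 2]
    target = torch.arange(2 * 5 * 4, dtype=torch.float).view(2, 5, 4)
    indices = torch.tensor([[1, 3], [0, 2]])

    expanded = indices.unsqueeze(-1).expand(-1, -1, target.size(-1))
    selected = torch.gather(target, 1, expanded)  # == batched_index_select(target, indices)
    print(selected.shape)  # torch.Size([2, 2, 4])
    ```
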
    The `EntityModel` wrapper class:
    ```python
    class EntityModel():
    
        def __init__(self, args, num_ner_labels):
            super().__init__()
    
            bert_model_name = args.model
            vocab_name = bert_model_name
    
            if args.bert_model_dir is not None:
                bert_model_name = str(args.bert_model_dir) + '/'
                # vocab_name = bert_model_name + 'vocab.txt'
                vocab_name = bert_model_name
                logger.info('Loading BERT model from {}'.format(bert_model_name))
    
            if args.use_albert:
                self.tokenizer = AlbertTokenizer.from_pretrained(vocab_name)
                self.bert_model = AlbertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels, max_span_length=args.max_span_length)
            else:
                self.tokenizer = BertTokenizer.from_pretrained(vocab_name)
                self.bert_model = BertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels, max_span_length=args.max_span_length)
    
            self._model_device = 'cpu'
            self.move_model_to_cuda()
    
        def move_model_to_cuda(self):
            if not torch.cuda.is_available():
                logger.error('No CUDA found!')
                exit(-1)
            logger.info('Moving to CUDA...')
            self._model_device = 'cuda'
            self.bert_model.cuda()
            logger.info('# GPUs = %d'%(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                self.bert_model = torch.nn.DataParallel(self.bert_model)
    
        def _get_input_tensors(self, tokens, spans, spans_ner_label):
            start2idx = []
            end2idx = []
    
            bert_tokens = []
            bert_tokens.append(self.tokenizer.cls_token)
            for token in tokens:
                start2idx.append(len(bert_tokens))
                sub_tokens = self.tokenizer.tokenize(token)
                bert_tokens += sub_tokens
                end2idx.append(len(bert_tokens)-1)
            bert_tokens.append(self.tokenizer.sep_token)
    
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(bert_tokens)
            tokens_tensor = torch.tensor([indexed_tokens])
    
            bert_spans = [[start2idx[span[0]], end2idx[span[1]], span[2]] for span in spans]
            bert_spans_tensor = torch.tensor([bert_spans])
    
            spans_ner_label_tensor = torch.tensor([spans_ner_label])
    
            return tokens_tensor, bert_spans_tensor, spans_ner_label_tensor
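
        # Hypothetical walk-through of the start2idx/end2idx mapping above.
        # Suppose tokens = ["John", "flew", "to", "Philadelphia"] and the
        # tokenizer splits "Philadelphia" into ["Phil", "##adel", "##phia"]:
        #   bert_tokens = [CLS] John flew to Phil ##adel ##phia [SEP]
        #   start2idx   = [1, 2, 3, 4]   (first wordpiece of each word)
        #   end2idx     = [1, 2, 3, 6]   (last wordpiece of each word)
        # A word-level span (3, 3, 1) over "Philadelphia" therefore becomes the
        # wordpiece-level span [4, 6, 1]; the width field keeps the word count.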
    
        def _get_input_tensors_batch(self, samples_list, training=True):
            tokens_tensor_list = []
            bert_spans_tensor_list = []
            spans_ner_label_tensor_list = []
            sentence_length = []
    
            max_tokens = 0
            max_spans = 0
            for sample in samples_list:
                tokens = sample['tokens']
                spans = sample['spans']
                spans_ner_label = sample['spans_label']
    
                tokens_tensor, bert_spans_tensor, spans_ner_label_tensor = self._get_input_tensors(tokens, spans, spans_ner_label)
                tokens_tensor_list.append(tokens_tensor)
                bert_spans_tensor_list.append(bert_spans_tensor)
                spans_ner_label_tensor_list.append(spans_ner_label_tensor)
                assert(bert_spans_tensor.shape[1] == spans_ner_label_tensor.shape[1])
                if (tokens_tensor.shape[1] > max_tokens):
                    max_tokens = tokens_tensor.shape[1]
                if (bert_spans_tensor.shape[1] > max_spans):
                    max_spans = bert_spans_tensor.shape[1]
                sentence_length.append(sample['sent_length'])
            sentence_length = torch.Tensor(sentence_length)
    
            # apply padding and concatenate tensors
            final_tokens_tensor = None
            final_attention_mask = None
            final_bert_spans_tensor = None
            final_spans_ner_label_tensor = None
            final_spans_mask_tensor = None
            for tokens_tensor, bert_spans_tensor, spans_ner_label_tensor in zip(tokens_tensor_list, bert_spans_tensor_list, spans_ner_label_tensor_list):
                # padding for tokens
                num_tokens = tokens_tensor.shape[1]
                tokens_pad_length = max_tokens - num_tokens
                attention_tensor = torch.full([1,num_tokens], 1, dtype=torch.long)
                if tokens_pad_length>0:
                    pad = torch.full([1,tokens_pad_length], self.tokenizer.pad_token_id, dtype=torch.long)
                    tokens_tensor = torch.cat((tokens_tensor, pad), dim=1)
                    attention_pad = torch.full([1,tokens_pad_length], 0, dtype=torch.long)
                    attention_tensor = torch.cat((attention_tensor, attention_pad), dim=1)
    
                # padding for spans
                num_spans = bert_spans_tensor.shape[1]
                spans_pad_length = max_spans - num_spans
                spans_mask_tensor = torch.full([1,num_spans], 1, dtype=torch.long)
                if spans_pad_length>0:
                    pad = torch.full([1,spans_pad_length,bert_spans_tensor.shape[2]], 0, dtype=torch.long)
                    bert_spans_tensor = torch.cat((bert_spans_tensor, pad), dim=1)
                    mask_pad = torch.full([1,spans_pad_length], 0, dtype=torch.long)
                    spans_mask_tensor = torch.cat((spans_mask_tensor, mask_pad), dim=1)
                    spans_ner_label_tensor = torch.cat((spans_ner_label_tensor, mask_pad), dim=1)
    
                # update final outputs
                if final_tokens_tensor is None:
                    final_tokens_tensor = tokens_tensor
                    final_attention_mask = attention_tensor
                    final_bert_spans_tensor = bert_spans_tensor
                    final_spans_ner_label_tensor = spans_ner_label_tensor
                    final_spans_mask_tensor = spans_mask_tensor
                else:
                    final_tokens_tensor = torch.cat((final_tokens_tensor,tokens_tensor), dim=0)
                    final_attention_mask = torch.cat((final_attention_mask, attention_tensor), dim=0)
                    final_bert_spans_tensor = torch.cat((final_bert_spans_tensor, bert_spans_tensor), dim=0)
                    final_spans_ner_label_tensor = torch.cat((final_spans_ner_label_tensor, spans_ner_label_tensor), dim=0)
                    final_spans_mask_tensor = torch.cat((final_spans_mask_tensor, spans_mask_tensor), dim=0)
            return final_tokens_tensor, final_attention_mask, final_bert_spans_tensor, final_spans_mask_tensor, final_spans_ner_label_tensor, sentence_length
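
        # Hypothetical illustration of the padding above: for a batch of two
        # sentences with 6 and 9 wordpieces, max_tokens = 9, so the first
        # sentence gets three pad_token_id entries appended and its attention
        # mask becomes [1,1,1,1,1,1,0,0,0]. Span rows are padded with [0,0,0]
        # triples and spans_mask_tensor marks them with 0, which forward()
        # turns into ignore_index labels so padded spans never affect the loss.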
    
        def run_batch(self, samples_list, try_cuda=True, training=True):
            # convert samples to input tensors
            tokens_tensor, attention_mask_tensor, bert_spans_tensor, spans_mask_tensor, spans_ner_label_tensor, sentence_length = self._get_input_tensors_batch(samples_list, training)
    
            output_dict = {
                'ner_loss': 0,
            }
    
            if training:
                self.bert_model.train()
                ner_loss, ner_logits, spans_embedding = self.bert_model(
                    input_ids = tokens_tensor.to(self._model_device),
                    spans = bert_spans_tensor.to(self._model_device),
                    spans_mask = spans_mask_tensor.to(self._model_device),
                    spans_ner_label = spans_ner_label_tensor.to(self._model_device),
                    attention_mask = attention_mask_tensor.to(self._model_device),
                )
                output_dict['ner_loss'] = ner_loss.sum()
                output_dict['ner_llh'] = F.log_softmax(ner_logits, dim=-1)
            else:
                self.bert_model.eval()
                with torch.no_grad():
                    ner_logits, spans_embedding, last_hidden = self.bert_model(
                        input_ids = tokens_tensor.to(self._model_device),
                        spans = bert_spans_tensor.to(self._model_device),
                        spans_mask = spans_mask_tensor.to(self._model_device),
                        spans_ner_label = None,
                        attention_mask = attention_mask_tensor.to(self._model_device),
                    )
                _, predicted_label = ner_logits.max(2)
                predicted_label = predicted_label.cpu().numpy()
                last_hidden = last_hidden.cpu().numpy()
    
                predicted = []
                pred_prob = []
                hidden = []
                for i, sample in enumerate(samples_list):
                    ner = []
                    prob = []
                    lh = []
                    for j in range(len(sample['spans'])):
                        ner.append(predicted_label[i][j])
                    # prob.append(F.softmax(ner_logits[i][j], dim=-1).cpu().numpy())
                    # note: 'ner_probs' therefore holds raw logits, not softmaxed probabilities
                    prob.append(ner_logits[i][j].cpu().numpy())
                        lh.append(last_hidden[i][j])
                    predicted.append(ner)
                    pred_prob.append(prob)
                    hidden.append(lh)
                output_dict['pred_ner'] = predicted
                output_dict['ner_probs'] = pred_prob
                output_dict['ner_last_hidden'] = hidden
    
            return output_dict
    ```
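
    To tie the pieces together, here is a hypothetical minimal call into `EntityModel.run_batch`. The sample keys (`tokens`, `spans`, `spans_label`, `sent_length`) follow the fields read in `_get_input_tensors_batch`; the `args` object, model name, and label count are made-up stand-ins, and a CUDA device is required because `move_model_to_cuda` exits without one:

    ```python
    from types import SimpleNamespace

    # Hypothetical configuration; 'bert-base-uncased' and num_ner_labels=7
    # are illustrative values, not fixed by the code above.
    args = SimpleNamespace(model='bert-base-uncased', bert_model_dir=None,
                           use_albert=False, max_span_length=8)
    model = EntityModel(args, num_ner_labels=7)

    sample = {
        'tokens': ['John', 'lives', 'in', 'Philadelphia'],
        # [start, end, width] over the original word-level tokens
        'spans': [[0, 0, 1], [3, 3, 1], [2, 3, 2]],
        'spans_label': [1, 2, 0],  # gold label id per span; 0 = not an entity
        'sent_length': 4,
    }

    out = model.run_batch([sample], training=False)
    print(out['pred_ner'])  # one predicted label id per candidate span
    ```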