• Data format

    A single JSON record has the following format:

    {
        "tokens": [
            "Newspaper"
            //...
        ],
        "entities": [
            {
                "type": "Loc",
                "start": 4,  // the entity is extracted from tokens[start:end]
                "end": 5
            }
            //...
        ],
        "relations": [
            {
                "type": "OrgBased_In",
                "head": 2,  // index of the head entity in "entities"
                "tail": 1
            }
        ],
        "orig_id": 3255
    }
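
    As a quick check that the format is understood, the snippet below is a minimal sketch (the file name and the plain json loading are assumptions for illustration, not part of the SpERT code) that recovers entity phrases and relation arguments from one record:

    import json

    # hypothetical file name; the dataset is a JSON list of records as shown above
    with open("conll04_train.json", "r", encoding="utf-8") as f:
        docs = json.load(f)

    doc = docs[0]
    for ent in doc["entities"]:
        # "start"/"end" index into the token list, end-exclusive
        print(ent["type"], " ".join(doc["tokens"][ent["start"]:ent["end"]]))

    for rel in doc["relations"]:
        # "head"/"tail" index into the "entities" list
        print(rel["type"], doc["entities"][rel["head"]], doc["entities"][rel["tail"]])
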
    • Processing the JSON data (input_reader.py)

    Each JSON record is mapped to a Document. A Document consists of Tokens, Entities and Relations.

    def _parse_document(self, doc, dataset) -> Document:
        jtokens = doc['tokens']
        jrelations = doc['relations']
        jentities = doc['entities']

        # parse tokens
        doc_tokens, doc_encoding = self._parse_tokens(jtokens, dataset)

        # parse entity mentions
        entities = self._parse_entities(jentities, doc_tokens, dataset)

        # parse relations
        relations = self._parse_relations(jrelations, entities, dataset)

        # create document
        document = dataset.create_document(doc_tokens, entities, relations, doc_encoding)

        return document

    • Class definitions (entities.py)

      class Token:
          def __init__(self, tid: int, index: int, span_start: int, span_end: int, phrase: str):
              self._tid = tid  # ID within the corresponding dataset
              self._index = index  # original token index in document
              self._span_start = span_start  # start of token span in document (inclusive)
              self._span_end = span_end  # end of token span in document (exclusive)
              self._phrase = phrase

      An Entity may consist of multiple Tokens.

      class Entity:
          def __init__(self, eid: int, entity_type: EntityType, tokens: List[Token], phrase: str):
              self._eid = eid  # ID within the corresponding dataset
              self._entity_type = entity_type
              self._tokens = tokens
              self._phrase = phrase

          @property
          def tokens(self):
              return TokenSpan(self._tokens)

      A Relation holds a head entity and a tail entity and takes the symmetry of the relation type into account.

      class Relation:
          def __init__(self, rid: int, relation_type: RelationType, head_entity: Entity,
                       tail_entity: Entity, reverse: bool = False):
              self._rid = rid  # ID within the corresponding dataset
              self._relation_type = relation_type
              self._head_entity = head_entity
              self._tail_entity = tail_entity
              self._reverse = reverse
              self._first_entity = head_entity if not reverse else tail_entity
              self._second_entity = tail_entity if not reverse else head_entity

      The TokenSpan class makes it convenient to obtain an Entity's span: it wraps a list of Tokens, and the returned span is the span_start of the first token and the span_end of the last token.

      class TokenSpan:
          def __init__(self, tokens):
              self._tokens = tokens

          @property
          def span_start(self):
              return self._tokens[0].span_start

          @property
          def span_end(self):
              return self._tokens[-1].span_end

          @property
          def span(self):
              return self.span_start, self.span_end
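
      A minimal usage sketch (the literal values are invented for illustration): an Entity built from two Tokens reports the byte-pair span covering both.

      # hypothetical tokens whose BPE pieces occupy positions 5-7 and 7-10 of the document encoding
      t1 = Token(tid=0, index=4, span_start=5, span_end=7, phrase="New")
      t2 = Token(tid=1, index=5, span_start=7, span_end=10, phrase="York")

      span = TokenSpan([t1, t2])
      print(span.span)  # (5, 10) -> span_start of the first token, span_end of the last token
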
    • Parsing Tokens, Entities and Relations (input_reader.py)

      def _parse_tokens(self, jtokens, dataset):
          doc_tokens = []
          # full document encoding including special tokens ([CLS] and [SEP]) and byte-pair encodings of original tokens
          doc_encoding = [self._tokenizer.convert_tokens_to_ids('[CLS]')]

          # parse tokens
          for i, token_phrase in enumerate(jtokens):
              token_encoding = self._tokenizer.encode(token_phrase, add_special_tokens=False)
              span_start, span_end = (len(doc_encoding), len(doc_encoding) + len(token_encoding))

              token = dataset.create_token(i, span_start, span_end, token_phrase)
              doc_tokens.append(token)
              doc_encoding += token_encoding

          doc_encoding += [self._tokenizer.convert_tokens_to_ids('[SEP]')]

          return doc_tokens, doc_encoding

      def _parse_entities(self, jentities, doc_tokens, dataset) -> List[Entity]:
          entities = []

          for entity_idx, jentity in enumerate(jentities):
              entity_type = self._entity_types[jentity['type']]
              start, end = jentity['start'], jentity['end']

              # create entity mention
              tokens = doc_tokens[start:end]
              phrase = " ".join([t.phrase for t in tokens])

              entity = dataset.create_entity(entity_type, tokens, phrase)
              entities.append(entity)

          return entities

      def _parse_relations(self, jrelations, entities, dataset) -> List[Relation]:
          relations = []

          for jrelation in jrelations:
              relation_type = self._relation_types[jrelation['type']]

              head_idx = jrelation['head']
              tail_idx = jrelation['tail']

              # create relation
              head = entities[head_idx]
              tail = entities[tail_idx]

              reverse = int(tail.tokens[0].index) < int(head.tokens[0].index)

              # for symmetric relations: head occurs before tail in sentence
              if relation_type.symmetric and reverse:
                  head, tail = util.swap(head, tail)

              relation = dataset.create_relation(relation_type, head_entity=head, tail_entity=tail, reverse=reverse)
              relations.append(relation)

          return relations

    • Creating the Dataset (entities.py)

      class Dataset(TorchDataset):
          TRAIN_MODE = 'train'
          EVAL_MODE = 'eval'

          def __init__(self, label, rel_types, entity_types, neg_entity_count,
                       neg_rel_count, max_span_size):
              self._label = label
              self._rel_types = rel_types
              self._entity_types = entity_types
              self._neg_entity_count = neg_entity_count
              self._neg_rel_count = neg_rel_count
              self._max_span_size = max_span_size
              self._mode = Dataset.TRAIN_MODE

              self._documents = OrderedDict()
              self._entities = OrderedDict()
              self._relations = OrderedDict()

              # current ids
              self._doc_id = 0
              self._rid = 0
              self._eid = 0
              self._tid = 0

          def __getitem__(self, index: int):
              doc = self._documents[index]

              if self._mode == Dataset.TRAIN_MODE:
                  return sampling.create_train_sample(doc, self._neg_entity_count, self._neg_rel_count,
                                                      self._max_span_size, len(self._rel_types))
              else:
                  return sampling.create_eval_sample(doc, self._max_span_size)

    • Sampling (sampling.py)

    Create the entity mask.

    def create_entity_mask(start, end, context_size):
        mask = torch.zeros(context_size, dtype=torch.bool)
        mask[start:end] = 1
        return mask
    

    Create the relation mask: the positions of the context between the head and tail entities are set to 1.

    def create_rel_mask(s1, s2, context_size):
        start = s1[1] if s1[1] < s2[0] else s2[1]
        end = s2[0] if s1[1] < s2[0] else s1[0]
        mask = create_entity_mask(start, end, context_size)
        return mask
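
    For example (a sketch with invented span boundaries): for entity spans s1 = (3, 5) and s2 = (8, 12), the mask covers positions 5..7, i.e. the encoding between the two spans.

    s1, s2 = (3, 5), (8, 12)  # hypothetical entity spans inside a context of 16 positions
    mask = create_rel_mask(s1, s2, context_size=16)
    print(mask.nonzero().view(-1).tolist())  # [5, 6, 7]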
    

    Create the training samples.
    The distinctive point is negative sampling: negative entity candidates are spans that are not labeled entities, and negative relation candidates are pairs of labeled entities between which no relation exists. Both are assigned label 0 ("none").

    def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span_size: int, rel_type_count: int):
        encodings = doc.encoding
        token_count = len(doc.tokens)
        context_size = len(encodings)
    
        # positive entities
        pos_entity_spans, pos_entity_types, pos_entity_masks, pos_entity_sizes = [], [], [], []
        for e in doc.entities:
            pos_entity_spans.append(e.span)
            pos_entity_types.append(e.entity_type.index)
            pos_entity_masks.append(create_entity_mask(*e.span, context_size))
            pos_entity_sizes.append(len(e.tokens))
    
        # positive relations
        pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], []
        for rel in doc.relations:
            s1, s2 = rel.head_entity.span, rel.tail_entity.span
            pos_rels.append((pos_entity_spans.index(s1), pos_entity_spans.index(s2)))
            pos_rel_spans.append((s1, s2))
            pos_rel_types.append(rel.relation_type)
            pos_rel_masks.append(create_rel_mask(s1, s2, context_size))
    
        # negative entities
        neg_entity_spans, neg_entity_sizes = [], []
        for size in range(1, max_span_size + 1):
            for i in range(0, (token_count - size) + 1):
                span = doc.tokens[i:i + size].span
                if span not in pos_entity_spans:
                    neg_entity_spans.append(span)
                    neg_entity_sizes.append(size)
    
        # sample negative entities
        neg_entity_samples = random.sample(list(zip(neg_entity_spans, neg_entity_sizes)),
                                           min(len(neg_entity_spans), neg_entity_count))
        neg_entity_spans, neg_entity_sizes = zip(*neg_entity_samples) if neg_entity_samples else ([], [])
    
        neg_entity_masks = [create_entity_mask(*span, context_size) for span in neg_entity_spans]
        neg_entity_types = [0] * len(neg_entity_spans)
    
        # negative relations
        # use only strong negative relations, i.e. pairs of actual (labeled) entities that are not related
        neg_rel_spans = []
    
        for i1, s1 in enumerate(pos_entity_spans):
            for i2, s2 in enumerate(pos_entity_spans):
                rev = (s2, s1)
                rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric
    
                # do not add as negative relation sample:
                # neg. relations from an entity to itself
                # entity pairs that are related according to gt
                # entity pairs whose reverse exists as a symmetric relation in gt
                if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric:
                    neg_rel_spans.append((s1, s2))
    
        # sample negative relations
        neg_rel_spans = random.sample(neg_rel_spans, min(len(neg_rel_spans), neg_rel_count))
    
        neg_rels = [(pos_entity_spans.index(s1), pos_entity_spans.index(s2)) for s1, s2 in neg_rel_spans]
        neg_rel_masks = [create_rel_mask(*spans, context_size) for spans in neg_rel_spans]
        neg_rel_types = [0] * len(neg_rel_spans)
    
        # merge
        entity_types = pos_entity_types + neg_entity_types
        entity_masks = pos_entity_masks + neg_entity_masks
        entity_sizes = pos_entity_sizes + list(neg_entity_sizes)
    
        rels = pos_rels + neg_rels
        rel_types = [r.index for r in pos_rel_types] + neg_rel_types
        rel_masks = pos_rel_masks + neg_rel_masks
    
        assert len(entity_masks) == len(entity_sizes) == len(entity_types)
        assert len(rels) == len(rel_masks) == len(rel_types)
    
        # create tensors
        # token indices
        encodings = torch.tensor(encodings, dtype=torch.long)
    
        # masking of tokens
        context_masks = torch.ones(context_size, dtype=torch.bool)
    
        # also create samples_masks:
        # tensors to mask entity/relation samples of batch
        # since samples are stacked into batches, "padding" entities/relations possibly must be created
        # these are later masked during loss computation
        if entity_masks:
            entity_types = torch.tensor(entity_types, dtype=torch.long)
            entity_masks = torch.stack(entity_masks)
            entity_sizes = torch.tensor(entity_sizes, dtype=torch.long)
            entity_sample_masks = torch.ones([entity_masks.shape[0]], dtype=torch.bool)
        else:
            # corner case handling (no pos/neg entities)
            entity_types = torch.zeros([1], dtype=torch.long)
            entity_masks = torch.zeros([1, context_size], dtype=torch.bool)
            entity_sizes = torch.zeros([1], dtype=torch.long)
            entity_sample_masks = torch.zeros([1], dtype=torch.bool)
    
        if rels:
            rels = torch.tensor(rels, dtype=torch.long)
            rel_masks = torch.stack(rel_masks)
            rel_types = torch.tensor(rel_types, dtype=torch.long)
            rel_sample_masks = torch.ones([rels.shape[0]], dtype=torch.bool)
        else:
            # corner case handling (no pos/neg relations)
            rels = torch.zeros([1, 2], dtype=torch.long)
            rel_types = torch.zeros([1], dtype=torch.long)
            rel_masks = torch.zeros([1, context_size], dtype=torch.bool)
            rel_sample_masks = torch.zeros([1], dtype=torch.bool)
    
        # relation types to one-hot encoding
        rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
        rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
        rel_types_onehot = rel_types_onehot[:, 1:]  # all zeros for 'none' relation
    
        return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
                    entity_sizes=entity_sizes, entity_types=entity_types,
                    rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
                    entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)
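
    The negative entity candidates above enumerate every span of length 1..max_span_size that is not a labeled entity, before neg_entity_count of them are sampled. A small standalone sketch of that enumeration (plain indices instead of the Document/TokenSpan objects):

    token_count, max_span_size = 10, 4
    candidate_spans = [(i, i + size)
                       for size in range(1, max_span_size + 1)
                       for i in range(0, (token_count - size) + 1)]
    print(len(candidate_spans))  # 10 + 9 + 8 + 7 = 34 candidate spans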
    

    创建验证样本

    def create_eval_sample(doc, max_span_size: int):
        encodings = doc.encoding
        token_count = len(doc.tokens)
        context_size = len(encodings)
    
        # create entity candidates
        entity_spans = []
        entity_masks = []
        entity_sizes = []
    
        for size in range(1, max_span_size + 1):
            for i in range(0, (token_count - size) + 1):
                span = doc.tokens[i:i + size].span
                entity_spans.append(span)
                entity_masks.append(create_entity_mask(*span, context_size))
                entity_sizes.append(size)
    
        # create tensors
        # token indices
        _encoding = encodings
        encodings = torch.zeros(context_size, dtype=torch.long)
        encodings[:len(_encoding)] = torch.tensor(_encoding, dtype=torch.long)
    
        # masking of tokens
        context_masks = torch.zeros(context_size, dtype=torch.bool)
        context_masks[:len(_encoding)] = 1
    
        # entities
        if entity_masks:
            entity_masks = torch.stack(entity_masks)
            entity_sizes = torch.tensor(entity_sizes, dtype=torch.long)
            entity_spans = torch.tensor(entity_spans, dtype=torch.long)
    
            # tensors to mask entity samples of batch
            # since samples are stacked into batches, "padding" entities possibly must be created
            # these are later masked during evaluation
            entity_sample_masks = torch.tensor([1] * entity_masks.shape[0], dtype=torch.bool)
        else:
            # corner case handling (no entities)
            entity_masks = torch.zeros([1, context_size], dtype=torch.bool)
            entity_sizes = torch.zeros([1], dtype=torch.long)
            entity_spans = torch.zeros([1, 2], dtype=torch.long)
            entity_sample_masks = torch.zeros([1], dtype=torch.bool)
    
        return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
                    entity_sizes=entity_sizes, entity_spans=entity_spans, entity_sample_masks=entity_sample_masks)
    

    Each sample is returned as a dict. The tensor sizes differ from sample to sample within a batch, so a custom collate_fn is needed to pad every sample to the maximum length in that batch, making the dimensions match.

    def collate_fn_padding(batch):
        padded_batch = dict()
        keys = batch[0].keys()
    
        for key in keys:
            samples = [s[key] for s in batch]
    
            if not batch[0][key].shape:
                padded_batch[key] = torch.stack(samples)
            else:
                padded_batch[key] = util.padded_stack([s[key] for s in batch])
    
        return padded_batch
    

    padded_stack in util.py:

    def padded_stack(tensors, padding=0):
        dim_count = len(tensors[0].shape)
    
        max_shape = [max([t.shape[d] for t in tensors]) for d in range(dim_count)]
        padded_tensors = []
    
        for t in tensors:
            e = extend_tensor(t, max_shape, fill=padding)
            padded_tensors.append(e)
    
        stacked = torch.stack(padded_tensors)
        return stacked
    def extend_tensor(tensor, extended_shape, fill=0):
        tensor_shape = tensor.shape
    
        extended_tensor = torch.zeros(extended_shape, dtype=tensor.dtype).to(tensor.device)
        extended_tensor = extended_tensor.fill_(fill)
    
        if len(tensor_shape) == 1:
            extended_tensor[:tensor_shape[0]] = tensor
        elif len(tensor_shape) == 2:
            extended_tensor[:tensor_shape[0], :tensor_shape[1]] = tensor
        elif len(tensor_shape) == 3:
            extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2]] = tensor
        elif len(tensor_shape) == 4:
            extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2], :tensor_shape[3]] = tensor
    
        return extended_tensor
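
    A quick sketch of the padding behaviour (the tensor values are invented): stacking two 1-D tensors of different lengths pads the shorter one with zeros.

    import torch

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([4, 5])
    print(padded_stack([a, b]))
    # tensor([[1, 2, 3],
    #         [4, 5, 0]])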
    

    This corresponds to the following code in spert_trainer.py:

    data_loader = DataLoader(dataset, batch_size=self.args.train_batch_size, shuffle=True, drop_last=True,
                             num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding)
    
    • Building the model (models.py)

      class SpERT(BertPreTrainedModel):
        def __init__(self, config: BertConfig, cls_token: int, relation_types: int, entity_types: int,
                     size_embedding: int, prop_drop: float, freeze_transformer: bool, max_pairs: int = 100):
            super(SpERT, self).__init__(config)
      
            # BERT model
            self.bert = BertModel(config)
      
            # layers
            self.rel_classifier = nn.Linear(config.hidden_size * 3 + size_embedding * 2, relation_types)
            self.entity_classifier = nn.Linear(config.hidden_size * 2 + size_embedding, entity_types)
            self.size_embeddings = nn.Embedding(100, size_embedding)
            self.dropout = nn.Dropout(prop_drop)
      
            self._cls_token = cls_token
            self._relation_types = relation_types
            self._entity_types = entity_types
            self._max_pairs = max_pairs
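
      The input widths of the two linear classifiers follow from how the representations are concatenated in _classify_entities and _classify_relations below. A small sanity-check sketch, assuming BERT-base (hidden_size = 768) and the default size_embedding = 25:

      hidden_size, size_embedding = 768, 25

      # entity representation: [CLS] context + max-pooled span + size embedding
      entity_in = hidden_size * 2 + size_embedding       # 1561
      # relation representation: pooled context + two entity spans + two size embeddings
      rel_in = hidden_size * 3 + size_embedding * 2      # 2354
      print(entity_in, rel_in)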
      

      A summary of the tensor shapes that appear during training:

    Tensor | Shape | Purpose
    encodings | [batch, seq] | sub-word ids of the sentence (long)
    context_masks (attention_masks) | [batch, seq] | sentence mask; padded positions are False (bool)
    h | [batch, seq, hidden=768] | BERT output given encodings and context_masks (float)
    entity_masks | [batch, num_of_entity, seq] | entity mask; positions of an entity's sub-word embeddings are True (bool)
    entity_sizes | [batch, num_of_entity] | entity length in tokens (long)
    size_embeddings | [batch, num_of_entity, size_embedding_dim=25] | entity sizes embedded via nn.Embedding, default dim 25 (float)
    entity_clf | [batch, num_of_entity, cls_of_entity] | entity-type scores; the argmax over the last dim is the predicted class (float)
    entity_spans_pool | [batch, num_of_entity, hidden=768] | max-pooled representation of each entity span, reused for relation classification (float)
    h_large | [batch, pair_of_relations, seq, hidden=768] | h repeated pair_of_relations times along a new dim 1 (float)
    rel_masks | [batch, pair_of_relations, seq] | mask over the encoding between the two entities of a pair (bool)
    relations | [batch, pair_of_relations, 2] | indices of the head and tail entity of each candidate pair (long)
    rel_clf | [batch, pair_of_relations, cls_of_relation] | relation-type scores; the argmax over the last dim is the predicted class (float)

    Training forward pass

     def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, entity_masks: torch.tensor,
                           entity_sizes: torch.tensor, relations: torch.tensor, rel_masks: torch.tensor):
            # get contextualized token embeddings from last transformer layer
            # encoding.shape=[batch, seq]
            context_masks = context_masks.float()
            h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]
            # h.shape=[batch, seq, hidden=768]
            batch_size = encodings.shape[0]
            # classify entities
            size_embeddings = self.size_embeddings(entity_sizes)  # embed entity candidate sizes
    
            # size_embeddings.shape=[batch, num_of_entity, size_embedding=25]
            # entity_size.shape=[batch, num_of_entity]
            # entity_masks.shape=[batch, num_of_entity, seq]
            entity_clf, entity_spans_pool = self._classify_entities(encodings, h, entity_masks, size_embeddings)
            # entity_clf.shape=[batch, num_of_entity, cls_of_entity]
            # entity_spans_pool.shape=[batch, num_of_entity, hidden=768]
    
            # classify relations
            h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
            rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
                self.rel_classifier.weight.device)
        # rel_clf.shape=[batch, pair_of_relation, types_of_relation]
    
            # obtain relation logits
            # chunk processing to reduce memory usage
    
            # relations.shape=[batch, pair_of_relation, 2]
            # relation_mask.shape=[batch, pair_of_relation, seq]
        for i in range(0, relations.shape[1], self._max_pairs):  # process relation candidates in chunks of at most max_pairs
                # classify relation candidates
                chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
                                                            relations, rel_masks, h_large, i)
                rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits
    
            return entity_clf, rel_clf
    

    Evaluation forward pass

        def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, entity_masks: torch.tensor,
                          entity_sizes: torch.tensor, entity_spans: torch.tensor, entity_sample_masks: torch.tensor):
            # get contextualized token embeddings from last transformer layer
            context_masks = context_masks.float()
            h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]
    
            batch_size = encodings.shape[0]
            ctx_size = context_masks.shape[-1]
    
            # classify entities
            size_embeddings = self.size_embeddings(entity_sizes)  # embed entity candidate sizes
            entity_clf, entity_spans_pool = self._classify_entities(encodings, h, entity_masks, size_embeddings)
    
            # ignore entity candidates that do not constitute an actual entity for relations (based on classifier)
            relations, rel_masks, rel_sample_masks = self._filter_spans(entity_clf, entity_spans,
                                                                        entity_sample_masks, ctx_size)
    
            rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)
            h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
            rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
                self.rel_classifier.weight.device)
    
            # obtain relation logits
            # chunk processing to reduce memory usage
            for i in range(0, relations.shape[1], self._max_pairs):
                # classify relation candidates
                chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
                                                            relations, rel_masks, h_large, i)
                # apply sigmoid
                chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
                rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
    
            rel_clf = rel_clf * rel_sample_masks  # mask
    
            # apply softmax
            entity_clf = torch.softmax(entity_clf, dim=2)
    
            return entity_clf, rel_clf, relations
    

    Entity classification

        def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
            # max pool entity candidate spans
            # encoding:[batch, seq]
            # h:[batch, seq, hidden]
            # entity_masks:[batch, num_of_entity, seq]
            # size_embeddings:[batch, num_of_entity, size_embedding=25]
            # entity_spans_pool:[batch, num_of_entity, hidden=768]
    
            m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
            entity_spans_pool = m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
            entity_spans_pool = entity_spans_pool.max(dim=2)[0]
    
            # get cls token as candidate context representation
            entity_ctx = get_token(h, encodings, self._cls_token)
            # entity_ctx:[batch, hidden=768]
            # create candidate representations including context, max pooled span and size embedding
            entity_repr = torch.cat([entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
                                     entity_spans_pool, size_embeddings], dim=2)
            # entity_repr:[batch, num_of_entity, hidden*2+size_embedding=768*2+25=1561]
            entity_repr = self.dropout(entity_repr)
    
            # classify entity candidates
            # entity_clf:[batch, num_of_entity, entity_types]
            entity_clf = self.entity_classifier(entity_repr)
    
            return entity_clf, entity_spans_pool
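
    The span representation comes from masked max pooling: positions outside the span are pushed to -1e30 before the max, so they can never be selected. A minimal standalone sketch of the same trick (shapes and values invented for illustration):

    import torch

    # h: [batch=1, seq=4, hidden=3]; one entity covering positions 1..2
    h = torch.tensor([[[0.1, 0.2, 0.3],
                       [0.9, 0.0, 0.5],
                       [0.4, 0.8, 0.2],
                       [0.7, 0.7, 0.7]]])
    entity_masks = torch.tensor([[[False, True, True, False]]])  # [batch=1, num_of_entity=1, seq=4]

    m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
    span_pool = (m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)).max(dim=2)[0]
    print(span_pool)  # tensor([[[0.9000, 0.8000, 0.5000]]]) -> element-wise max over positions 1 and 2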
    

    Relation classification

        def _classify_relations(self, entity_spans, size_embeddings, relations, rel_masks, h, chunk_start):
            batch_size = relations.shape[0]
    
            # entity_spans_pool.shape=[batch, num_of_entity, hidden=768]
            # size_embeddings.shape=[batch, num_of_entity, size_embedding=25]
            # relations.shape=[batch, pair_of_relation, 2]
            # relation_mask.shape=[batch, pair_of_relation, seq]
            # h.shape=[batch, min(pair_of_relation, self._max_pairs), seq, hidden = 768]
    
            # create chunks if necessary
            if relations.shape[1] > self._max_pairs:
                relations = relations[:, chunk_start:chunk_start + self._max_pairs]
                rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
                h = h[:, :relations.shape[1], :]
    
            # get pairs of entity candidate representations
            # print("entity_span", entity_spans.shape)
            # print("relations", relations.shape)
            entity_pairs = util.batch_index(entity_spans, relations)
            # entity_pairs.shape=[batch, pair_of_relation, 2, hidden=768]
            entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1)
            # entity_pairs.shape=[batch, pair_of_relation, 2*hidden=1536]
            # get corresponding size embeddings
            size_pair_embeddings = util.batch_index(size_embeddings, relations)
            # size_pair_embeddings=[batch, pair_of_relation, 2, size_embedding=25]
            size_pair_embeddings = size_pair_embeddings.view(batch_size, size_pair_embeddings.shape[1], -1)
            # size_pair_embeddings=[batch, pair_of_relation, 2 * size_embedding=25]
    
            # relation context (context between entity candidate pair)
            # mask non entity candidate tokens
            m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
    
            rel_ctx = m + h
    
            # rel_ctx=[batch, min(pair_of_relation, self._max_pairs), seq, hidden=768]
            # max pooling
            rel_ctx = rel_ctx.max(dim=2)[0]
            # rel_ctx=[batch, min(pair_of_relation, self._max_pairs), hidden=768]
            # set the context vector of neighboring or adjacent entity candidates to zero
            rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0
    
            # create relation candidate representations including context, max pooled entity candidate pairs
            # and corresponding size embeddings
            rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings], dim=2)
            # rel_repr=[batch, pair_of_relation, 3*hidden+size_embedding]
            rel_repr = self.dropout(rel_repr)
    
            # classify relation candidates
            chunk_rel_logits = self.rel_classifier(rel_repr)
            # chunk_rel_logits=[batch, pair_of_relation, relation_types]
            return chunk_rel_logits
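
    util.batch_index gathers, for every batch element, the representations of the two entities referenced by each relation candidate. The sketch below is an assumption about what such a batched index-select does (written to illustrate the shapes, not copied from util.py):

    import torch

    def batch_index_sketch(tensor, index):
        # tensor: [batch, n, ...]; index: [batch, m, 2] of indices into dim 1
        return torch.stack([tensor[i][index[i]] for i in range(index.shape[0])])

    entity_spans = torch.randn(2, 5, 768)       # [batch=2, num_of_entity=5, hidden=768]
    relations = torch.randint(0, 5, (2, 7, 2))  # [batch=2, pair_of_relation=7, 2]
    print(batch_index_sketch(entity_spans, relations).shape)  # torch.Size([2, 7, 2, 768])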
    

    Span filtering

        def _filter_spans(self, entity_clf, entity_spans, entity_sample_masks, ctx_size):
            batch_size = entity_clf.shape[0]
            entity_logits_max = entity_clf.argmax(dim=-1) * entity_sample_masks.long()  # get entity type (including none)
            batch_relations = []
            batch_rel_masks = []
            batch_rel_sample_masks = []
    
            for i in range(batch_size):
                rels = []
                rel_masks = []
                sample_masks = []
    
                # get spans classified as entities
                non_zero_indices = (entity_logits_max[i] != 0).nonzero().view(-1)
                non_zero_spans = entity_spans[i][non_zero_indices].tolist()
                non_zero_indices = non_zero_indices.tolist()
    
                # create relations and masks
                for i1, s1 in zip(non_zero_indices, non_zero_spans):
                    for i2, s2 in zip(non_zero_indices, non_zero_spans):
                        if i1 != i2:
                            rels.append((i1, i2))
                            rel_masks.append(sampling.create_rel_mask(s1, s2, ctx_size))
                            sample_masks.append(1)
    
                if not rels:
                # case: fewer than two spans classified as entities
                    batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long))
                    batch_rel_masks.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
                    batch_rel_sample_masks.append(torch.tensor([0], dtype=torch.bool))
                else:
                # case: at least two spans classified as entities
                    batch_relations.append(torch.tensor(rels, dtype=torch.long))
                    batch_rel_masks.append(torch.stack(rel_masks))
                    batch_rel_sample_masks.append(torch.tensor(sample_masks, dtype=torch.bool))
    
            # stack
            device = self.rel_classifier.weight.device
            batch_relations = util.padded_stack(batch_relations).to(device)
            batch_rel_masks = util.padded_stack(batch_rel_masks).to(device)
            batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(device)
    
            return batch_relations, batch_rel_masks, batch_rel_sample_masks