- Data format
A single JSON record looks like this:
{
"tokens": [
"Newspaper"
//...
],
"entities": [
{
"type": "Loc",
"start": 4, //Entity提取自tokens[start:end]
"end": 5
}
//...
],
"relations": [
{
"type": "OrgBased_In",
"head": 2, //对应头实体的下标
"tail": 1
}
],
"orig_id": 3255
}
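For illustration, here is a hypothetical complete record (the sentence and labels below are invented; the type names follow the CoNLL04 label set used above). Entity 1 ("Acme Corp") is tokens[4:6], and the OrgBased_In relation points from entity 1 (the Org) to entity 2 (the Loc):
{
"tokens": ["John", "Smith", "works", "for", "Acme", "Corp", "in", "Boston", "."],
"entities": [
{"type": "Peop", "start": 0, "end": 2},
{"type": "Org", "start": 4, "end": 6},
{"type": "Loc", "start": 7, "end": 8}
],
"relations": [
{"type": "Work_For", "head": 0, "tail": 1},
{"type": "OrgBased_In", "head": 1, "tail": 2}
],
"orig_id": 0
}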
- Parsing the JSON data (input_reader.py)
Each JSON record is mapped to one Document. A Document is made up of Tokens, Entities, and Relations.
def _parse_document(self, doc, dataset) -> Document:
jtokens = doc['tokens']
jrelations = doc['relations']
jentities = doc['entities']
# parse tokens
doc_tokens, doc_encoding = self._parse_tokens(jtokens, dataset)
# parse entity mentions
entities = self._parse_entities(jentities, doc_tokens, dataset)
# parse relations
relations = self._parse_relations(jrelations, entities, dataset)
# create document
document = dataset.create_document(doc_tokens, entities, relations, doc_encoding)
return document
Class definitions (entities.py)
class Token:
def __init__(self, tid: int, index: int, span_start: int, span_end: int, phrase: str):
self._tid = tid # ID within the corresponding dataset
self._index = index # original token index in document
self._span_start = span_start # start of token span in document (inclusive)
self._span_end = span_end # end of token span in document (exclusive)
self._phrase = phrase
An Entity may consist of multiple Tokens.
class Entity:
def __init__(self, eid: int, entity_type: EntityType, tokens: List[Token], phrase: str):
self._eid = eid # ID within the corresponding dataset
self._entity_type = entity_type
self._tokens = tokens
self._phrase = phrase
@property
def tokens(self):
return TokenSpan(self._tokens)
A Relation contains a head entity and a tail entity, and accounts for symmetric relation types.
class Relation:
def __init__(self, rid: int, relation_type: RelationType, head_entity: Entity,
tail_entity: Entity, reverse: bool = False):
self._rid = rid # ID within the corresponding dataset
self._relation_type = relation_type
self._head_entity = head_entity
self._tail_entity = tail_entity
self._reverse = reverse
self._first_entity = head_entity if not reverse else tail_entity
self._second_entity = tail_entity if not reverse else head_entity
The TokenSpan class makes it easy to obtain an Entity's span: it wraps a list of Tokens, and the returned span runs from the first token's span_start to the last token's span_end.
class TokenSpan:
def __init__(self, tokens):
self._tokens = tokens
@property
def span_start(self):
return self._tokens[0].span_start
@property
def span_end(self):
return self._tokens[-1].span_end
@property
def span(self):
return self.span_start, self.span_end
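A quick usage sketch (in the repository, Token also exposes span_start, span_end, and phrase as properties; they are omitted above for brevity):
t_new = Token(tid=0, index=6, span_start=8, span_end=9, phrase="New")
t_york = Token(tid=1, index=7, span_start=9, span_end=11, phrase="York")  # suppose "York" splits into two subtokens
assert TokenSpan([t_new, t_york]).span == (8, 11)  # first token's start, last token's end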
Parsing Tokens, Entities, and Relations (input_reader.py)
def _parse_tokens(self, jtokens, dataset):
doc_tokens = []
# full document encoding including special tokens ([CLS] and [SEP]) and byte-pair encodings of original tokens
doc_encoding = [self._tokenizer.convert_tokens_to_ids('[CLS]')]
# parse tokens
for i, token_phrase in enumerate(jtokens):
token_encoding = self._tokenizer.encode(token_phrase, add_special_tokens=False)
span_start, span_end = (len(doc_encoding), len(doc_encoding) + len(token_encoding))
token = dataset.create_token(i, span_start, span_end, token_phrase)
doc_tokens.append(token)
doc_encoding += token_encoding
doc_encoding += [self._tokenizer.convert_tokens_to_ids('[SEP]')]
return doc_tokens, doc_encoding
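To see how subword splitting drives the spans, here is a minimal standalone sketch of the same loop (assumes the transformers library; the exact ids and split points depend on the vocabulary):
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
doc_encoding = [tokenizer.convert_tokens_to_ids('[CLS]')]
for i, token_phrase in enumerate(["Newspaper", "publishers"]):
    token_encoding = tokenizer.encode(token_phrase, add_special_tokens=False)
    span_start, span_end = len(doc_encoding), len(doc_encoding) + len(token_encoding)
    print(token_phrase, (span_start, span_end))  # spans count subtoken positions; [CLS] sits at position 0
    doc_encoding += token_encoding
doc_encoding += [tokenizer.convert_tokens_to_ids('[SEP]')]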
def _parse_entities(self, jentities, doc_tokens, dataset) -> List[Entity]:
entities = []
for entity_idx, jentity in enumerate(jentities):
entity_type = self._entity_types[jentity['type']]
start, end = jentity['start'], jentity['end']
# create entity mention
tokens = doc_tokens[start:end]
phrase = " ".join([t.phrase for t in tokens])
entity = dataset.create_entity(entity_type, tokens, phrase)
entities.append(entity)
return entities
def _parse_relations(self, jrelations, entities, dataset) -> List[Relation]:
relations = []
for jrelation in jrelations:
relation_type = self._relation_types[jrelation['type']]
head_idx = jrelation['head']
tail_idx = jrelation['tail']
# create relation
head = entities[head_idx]
tail = entities[tail_idx]
reverse = int(tail.tokens[0].index) < int(head.tokens[0].index)
# for symmetric relations: head occurs before tail in sentence
if relation_type.symmetric and reverse:
head, tail = util.swap(head, tail)
relation = dataset.create_relation(relation_type, head_entity=head, tail_entity=tail, reverse=reverse)
relations.append(relation)
return relations
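The reverse flag only records that the tail entity precedes the head in the sentence; an actual swap happens solely for symmetric relation types. A hypothetical trace (indices invented):
# Suppose the head entity starts at token 7 and the tail entity at token 2:
reverse = 2 < 7                      # True: the tail occurs first in the sentence
# 'OrgBased_In' is a directed type, so head and tail keep their labeled roles.
# For a symmetric relation type, util.swap(head, tail) would reorder the pair
# so that the head is the entity that occurs first in the sentence.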
Building the Dataset (entities.py)
class Dataset(TorchDataset):
TRAIN_MODE = 'train'
EVAL_MODE = 'eval'
def __init__(self, label, rel_types, entity_types, neg_entity_count,
neg_rel_count, max_span_size):
self._label = label
self._rel_types = rel_types
self._entity_types = entity_types
self._neg_entity_count = neg_entity_count
self._neg_rel_count = neg_rel_count
self._max_span_size = max_span_size
self._mode = Dataset.TRAIN_MODE
self._documents = OrderedDict()
self._entities = OrderedDict()
self._relations = OrderedDict()
# current ids
self._doc_id = 0
self._rid = 0
self._eid = 0
self._tid = 0
def __getitem__(self, index: int):
doc = self._documents[index]
if self._mode == Dataset.TRAIN_MODE:
return sampling.create_train_sample(doc, self._neg_entity_count, self._neg_rel_count,
self._max_span_size, len(self._rel_types))
else:
return sampling.create_eval_sample(doc, self._max_span_size)
Sampling (sampling.py)
Creating the entity mask:
def create_entity_mask(start, end, context_size):
mask = torch.zeros(context_size, dtype=torch.bool)
mask[start:end] = 1
return mask
Creating the relation mask: the context positions between the two entities are set to 1.
def create_rel_mask(s1, s2, context_size):
start = s1[1] if s1[1] < s2[0] else s2[1]
end = s2[0] if s1[1] < s2[0] else s1[0]
mask = create_entity_mask(start, end, context_size)
return mask
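A worked example (assuming torch is imported and the two functions above are defined): with subtoken spans s1=(3, 5) and s2=(9, 12), exactly the positions strictly between the two entities are set to 1:
mask = create_rel_mask((3, 5), (9, 12), context_size=15)
print(mask.nonzero().view(-1))  # tensor([5, 6, 7, 8]): the context between the two spans
If the two spans are adjacent or overlap, the slice is empty and the mask stays all-zero; _classify_relations later zeroes the context vector of exactly those pairs.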
Creating training samples
The distinctive part is negative sampling. Negative entity samples are spans that are not gold entities; negative relation samples are pairs of gold entities between which no relation holds. Both get label 0 ('none').
def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span_size: int, rel_type_count: int):
encodings = doc.encoding
token_count = len(doc.tokens)
context_size = len(encodings)
# positive entities
pos_entity_spans, pos_entity_types, pos_entity_masks, pos_entity_sizes = [], [], [], []
for e in doc.entities:
pos_entity_spans.append(e.span)
pos_entity_types.append(e.entity_type.index)
pos_entity_masks.append(create_entity_mask(*e.span, context_size))
pos_entity_sizes.append(len(e.tokens))
# positive relations
pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], []
for rel in doc.relations:
s1, s2 = rel.head_entity.span, rel.tail_entity.span
pos_rels.append((pos_entity_spans.index(s1), pos_entity_spans.index(s2)))
pos_rel_spans.append((s1, s2))
pos_rel_types.append(rel.relation_type)
pos_rel_masks.append(create_rel_mask(s1, s2, context_size))
# negative entities
neg_entity_spans, neg_entity_sizes = [], []
for size in range(1, max_span_size + 1):
for i in range(0, (token_count - size) + 1):
span = doc.tokens[i:i + size].span
if span not in pos_entity_spans:
neg_entity_spans.append(span)
neg_entity_sizes.append(size)
# sample negative entities
neg_entity_samples = random.sample(list(zip(neg_entity_spans, neg_entity_sizes)),
min(len(neg_entity_spans), neg_entity_count))
neg_entity_spans, neg_entity_sizes = zip(*neg_entity_samples) if neg_entity_samples else ([], [])
neg_entity_masks = [create_entity_mask(*span, context_size) for span in neg_entity_spans]
neg_entity_types = [0] * len(neg_entity_spans)
# negative relations
# use only strong negative relations, i.e. pairs of actual (labeled) entities that are not related
neg_rel_spans = []
for i1, s1 in enumerate(pos_entity_spans):
for i2, s2 in enumerate(pos_entity_spans):
rev = (s2, s1)
rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric
# do not add as negative relation sample:
# neg. relations from an entity to itself
# entity pairs that are related according to gt
# entity pairs whose reverse exists as a symmetric relation in gt
if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric:
neg_rel_spans.append((s1, s2))
# sample negative relations
neg_rel_spans = random.sample(neg_rel_spans, min(len(neg_rel_spans), neg_rel_count))
neg_rels = [(pos_entity_spans.index(s1), pos_entity_spans.index(s2)) for s1, s2 in neg_rel_spans]
neg_rel_masks = [create_rel_mask(*spans, context_size) for spans in neg_rel_spans]
neg_rel_types = [0] * len(neg_rel_spans)
# merge
entity_types = pos_entity_types + neg_entity_types
entity_masks = pos_entity_masks + neg_entity_masks
entity_sizes = pos_entity_sizes + list(neg_entity_sizes)
rels = pos_rels + neg_rels
rel_types = [r.index for r in pos_rel_types] + neg_rel_types
rel_masks = pos_rel_masks + neg_rel_masks
assert len(entity_masks) == len(entity_sizes) == len(entity_types)
assert len(rels) == len(rel_masks) == len(rel_types)
# create tensors
# token indices
encodings = torch.tensor(encodings, dtype=torch.long)
# masking of tokens
context_masks = torch.ones(context_size, dtype=torch.bool)
# also create samples_masks:
# tensors to mask entity/relation samples of batch
# since samples are stacked into batches, "padding" entities/relations possibly must be created
# these are later masked during loss computation
if entity_masks:
entity_types = torch.tensor(entity_types, dtype=torch.long)
entity_masks = torch.stack(entity_masks)
entity_sizes = torch.tensor(entity_sizes, dtype=torch.long)
entity_sample_masks = torch.ones([entity_masks.shape[0]], dtype=torch.bool)
else:
# corner case handling (no pos/neg entities)
entity_types = torch.zeros([1], dtype=torch.long)
entity_masks = torch.zeros([1, context_size], dtype=torch.bool)
entity_sizes = torch.zeros([1], dtype=torch.long)
entity_sample_masks = torch.zeros([1], dtype=torch.bool)
if rels:
rels = torch.tensor(rels, dtype=torch.long)
rel_masks = torch.stack(rel_masks)
rel_types = torch.tensor(rel_types, dtype=torch.long)
rel_sample_masks = torch.ones([rels.shape[0]], dtype=torch.bool)
else:
# corner case handling (no pos/neg relations)
rels = torch.zeros([1, 2], dtype=torch.long)
rel_types = torch.zeros([1], dtype=torch.long)
rel_masks = torch.zeros([1, context_size], dtype=torch.bool)
rel_sample_masks = torch.zeros([1], dtype=torch.bool)
# relation types to one-hot encoding
rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation
return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
entity_sizes=entity_sizes, entity_types=entity_types,
rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)
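Stripped of the tensor bookkeeping, the negative-entity part is simply "every span up to max_span_size that is not a gold entity", followed by random sampling. A toy illustration with invented numbers (the real code enumerates subtoken spans via doc.tokens[i:i + size].span):
import random

token_count, max_span_size = 5, 2
pos_entity_spans = [(1, 2)]                      # one gold entity
neg_entity_spans = [(i, i + size)
                    for size in range(1, max_span_size + 1)
                    for i in range(token_count - size + 1)
                    if (i, i + size) not in pos_entity_spans]
print(len(neg_entity_spans))                                           # 8 candidate spans
print(random.sample(neg_entity_spans, min(len(neg_entity_spans), 3)))  # e.g. [(0, 2), (3, 4), (2, 3)]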
Creating evaluation samples
def create_eval_sample(doc, max_span_size: int):
encodings = doc.encoding
token_count = len(doc.tokens)
context_size = len(encodings)
# create entity candidates
entity_spans = []
entity_masks = []
entity_sizes = []
for size in range(1, max_span_size + 1):
for i in range(0, (token_count - size) + 1):
span = doc.tokens[i:i + size].span
entity_spans.append(span)
entity_masks.append(create_entity_mask(*span, context_size))
entity_sizes.append(size)
# create tensors
# token indices
_encoding = encodings
encodings = torch.zeros(context_size, dtype=torch.long)
encodings[:len(_encoding)] = torch.tensor(_encoding, dtype=torch.long)
# masking of tokens
context_masks = torch.zeros(context_size, dtype=torch.bool)
context_masks[:len(_encoding)] = 1
# entities
if entity_masks:
entity_masks = torch.stack(entity_masks)
entity_sizes = torch.tensor(entity_sizes, dtype=torch.long)
entity_spans = torch.tensor(entity_spans, dtype=torch.long)
# tensors to mask entity samples of batch
# since samples are stacked into batches, "padding" entities possibly must be created
# these are later masked during evaluation
entity_sample_masks = torch.tensor([1] * entity_masks.shape[0], dtype=torch.bool)
else:
# corner case handling (no entities)
entity_masks = torch.zeros([1, context_size], dtype=torch.bool)
entity_sizes = torch.zeros([1], dtype=torch.long)
entity_spans = torch.zeros([1, 2], dtype=torch.long)
entity_sample_masks = torch.zeros([1], dtype=torch.bool)
return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
entity_sizes=entity_sizes, entity_spans=entity_spans, entity_sample_masks=entity_sample_masks)
Each sample is returned as a dict. The tensor sizes vary from sample to sample, so a custom collate_fn is needed that pads every sample in a batch to the batch's maximum length, giving all tensors in the batch the same dimensions.
def collate_fn_padding(batch):
padded_batch = dict()
keys = batch[0].keys()
for key in keys:
samples = [s[key] for s in batch]
if not batch[0][key].shape:
padded_batch[key] = torch.stack(samples)
else:
padded_batch[key] = util.padded_stack([s[key] for s in batch])
return padded_batch
padded_stack in util:
def padded_stack(tensors, padding=0):
dim_count = len(tensors[0].shape)
max_shape = [max([t.shape[d] for t in tensors]) for d in range(dim_count)]
padded_tensors = []
for t in tensors:
e = extend_tensor(t, max_shape, fill=padding)
padded_tensors.append(e)
stacked = torch.stack(padded_tensors)
return stacked
def extend_tensor(tensor, extended_shape, fill=0):
tensor_shape = tensor.shape
extended_tensor = torch.zeros(extended_shape, dtype=tensor.dtype).to(tensor.device)
extended_tensor = extended_tensor.fill_(fill)
if len(tensor_shape) == 1:
extended_tensor[:tensor_shape[0]] = tensor
elif len(tensor_shape) == 2:
extended_tensor[:tensor_shape[0], :tensor_shape[1]] = tensor
elif len(tensor_shape) == 3:
extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2]] = tensor
elif len(tensor_shape) == 4:
extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2], :tensor_shape[3]] = tensor
return extended_tensor
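A quick check of the padding behavior (assuming torch is imported and the two helpers above are defined):
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
print(padded_stack([a, b]))
# tensor([[1, 2, 3],
#         [4, 5, 0]])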
This corresponds to the following code in spert_trainer.py:
data_loader = DataLoader(dataset, batch_size=self.args.train_batch_size, shuffle=True, drop_last=True,
num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding)
Building the model (models.py)
class SpERT(BertPreTrainedModel):
    def __init__(self, config: BertConfig, cls_token: int, relation_types: int, entity_types: int,
                 size_embedding: int, prop_drop: float, freeze_transformer: bool, max_pairs: int = 100):
        super(SpERT, self).__init__(config)
        # BERT model
        self.bert = BertModel(config)
        # layers
        self.rel_classifier = nn.Linear(config.hidden_size * 3 + size_embedding * 2, relation_types)
        self.entity_classifier = nn.Linear(config.hidden_size * 2 + size_embedding, entity_types)
        self.size_embeddings = nn.Embedding(100, size_embedding)
        self.dropout = nn.Dropout(prop_drop)
        self._cls_token = cls_token
        self._relation_types = relation_types
        self._entity_types = entity_types
        self._max_pairs = max_pairs
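The classifier input widths follow directly from the concatenated features (with BERT's hidden_size=768 and the default size_embedding=25):
hidden, size_emb = 768, 25
entity_in = hidden * 2 + size_emb   # [CLS] context + max-pooled span + size embedding = 1561
rel_in = hidden * 3 + size_emb * 2  # pooled context + two entity spans + two size embeddings = 2354
print(entity_in, rel_in)            # 1561 2354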
A summary of the tensor shapes that appear during training:
Tensor | Shape | Purpose
---|---|---
encodings | [batch, seq] | token-id encoding of the sentence; long
context_masks (attention_masks) | [batch, seq] | sentence mask; padded positions are False; bool
h | [batch, seq, hidden=768] | BERT output for encodings and context_masks
entity_masks | [batch, num_of_entity, seq] | entity mask; positions of the entity's subtoken embeddings are True; bool
entity_sizes | [batch, num_of_entity] | entity lengths; long
size_embeddings | [batch, num_of_entity, size_embedding_dim=25] | entity lengths embedded via nn.Embedding (default dim=25); float
entity_clf | [batch, num_of_entity, cls_of_entity] | predicted score for each entity class; float; the argmax of the last dimension is the predicted class
entity_spans_pool | [batch, num_of_entity, hidden=768] | max-pooled representation of each entity span; float; reused for relation classification
h_large | [batch, pair_of_relations, seq, hidden=768] | h repeated pair_of_relations times along dim 1; float
rel_masks | [batch, pair_of_relations, seq] | mask over the encoding between the two entities of a pair; bool
relations | [batch, pair_of_relations, 2] | indices of the head and tail entity of each candidate pair; long
rel_clf | [batch, pair_of_relations, cls_of_relation] | predicted score for each relation class; float; the argmax of the last dimension is the predicted class
Training forward pass
def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, entity_masks: torch.tensor,
entity_sizes: torch.tensor, relations: torch.tensor, rel_masks: torch.tensor):
# get contextualized token embeddings from last transformer layer
# encoding.shape=[batch, seq]
context_masks = context_masks.float()
h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]
# h.shape=[batch, seq, hidden=768]
batch_size = encodings.shape[0]
# classify entities
size_embeddings = self.size_embeddings(entity_sizes) # embed entity candidate sizes
# size_embeddings.shape=[batch, num_of_entity, size_embedding=25]
# entity_size.shape=[batch, num_of_entity]
# entity_masks.shape=[batch, num_of_entity, seq]
entity_clf, entity_spans_pool = self._classify_entities(encodings, h, entity_masks, size_embeddings)
# entity_clf.shape=[batch, num_of_entity, cls_of_entity]
# entity_spans_pool.shape=[batch, num_of_entity, hidden=768]
# classify relations
h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
self.rel_classifier.weight.device)
# rel_clf.shape=[batch, pair_of_relation, types_of_relation]
# obtain relation logits
# chunk processing to reduce memory usage
# relations.shape=[batch, pair_of_relation, 2]
# relation_mask.shape=[batch, pair_of_relation, seq]
for i in range(0, relations.shape[1], self._max_pairs):  # process candidate pairs in chunks of self._max_pairs to bound memory
# classify relation candidates
chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
relations, rel_masks, h_large, i)
rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits
return entity_clf, rel_clf
Evaluation forward pass
def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, entity_masks: torch.tensor,
entity_sizes: torch.tensor, entity_spans: torch.tensor, entity_sample_masks: torch.tensor):
# get contextualized token embeddings from last transformer layer
context_masks = context_masks.float()
h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]
batch_size = encodings.shape[0]
ctx_size = context_masks.shape[-1]
# classify entities
size_embeddings = self.size_embeddings(entity_sizes) # embed entity candidate sizes
entity_clf, entity_spans_pool = self._classify_entities(encodings, h, entity_masks, size_embeddings)
# ignore entity candidates that do not constitute an actual entity for relations (based on classifier)
relations, rel_masks, rel_sample_masks = self._filter_spans(entity_clf, entity_spans,
entity_sample_masks, ctx_size)
rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)
h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
self.rel_classifier.weight.device)
# obtain relation logits
# chunk processing to reduce memory usage
for i in range(0, relations.shape[1], self._max_pairs):
# classify relation candidates
chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
relations, rel_masks, h_large, i)
# apply sigmoid
chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
rel_clf = rel_clf * rel_sample_masks # mask
# apply softmax
entity_clf = torch.softmax(entity_clf, dim=2)
return entity_clf, rel_clf, relations
Entity classification
def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
# max pool entity candidate spans
# encoding:[batch, seq]
# h:[batch, seq, hidden]
# entity_masks:[batch, num_of_entity, seq]
# size_embeddings:[batch, num_of_entity, size_embedding=25]
# entity_spans_pool:[batch, num_of_entity, hidden=768]
m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
entity_spans_pool = m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
entity_spans_pool = entity_spans_pool.max(dim=2)[0]
# get cls token as candidate context representation
entity_ctx = get_token(h, encodings, self._cls_token)
# entity_ctx:[batch, hidden=768]
# create candidate representations including context, max pooled span and size embedding
entity_repr = torch.cat([entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
entity_spans_pool, size_embeddings], dim=2)
# entity_repr:[batch, num_of_entity, hidden*2+size_embedding=768*2+25=1561]
entity_repr = self.dropout(entity_repr)
# classify entity candidates
# entity_clf:[batch, num_of_entity, entity_types]
entity_clf = self.entity_classifier(entity_repr)
return entity_clf, entity_spans_pool
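The (-1e30) trick turns a plain max into a max over just the span: masked-out positions become hugely negative and never win the max. A toy check with invented numbers (get_token, defined in util, picks out the hidden states at the positions where encodings equals the given token id, i.e. the [CLS] representation here):
import torch

h = torch.tensor([[1.0, 5.0, 3.0, 2.0]]).unsqueeze(-1)           # [batch=1, seq=4, hidden=1]
entity_masks = torch.tensor([[[0, 1, 1, 0]]], dtype=torch.bool)  # one entity covering positions 1..2
m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
pooled = (m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)).max(dim=2)[0]
print(pooled)                                                    # tensor([[[5.]]]): max over the span only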
Relation classification
def _classify_relations(self, entity_spans, size_embeddings, relations, rel_masks, h, chunk_start):
batch_size = relations.shape[0]
# entity_spans_pool.shape=[batch, num_of_entity, hidden=768]
# size_embeddings.shape=[batch, num_of_entity, size_embedding=25]
# relations.shape=[batch, pair_of_relation, 2]
# relation_mask.shape=[batch, pair_of_relation, seq]
# h.shape=[batch, min(pair_of_relation, self._max_pairs), seq, hidden = 768]
# create chunks if necessary
if relations.shape[1] > self._max_pairs:
relations = relations[:, chunk_start:chunk_start + self._max_pairs]
rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
h = h[:, :relations.shape[1], :]
# get pairs of entity candidate representations
# print("entity_span", entity_spans.shape)
# print("relations", relations.shape)
entity_pairs = util.batch_index(entity_spans, relations)
# entity_pairs.shape=[batch, pair_of_relation, 2, hidden=768]
entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1)
# entity_pairs.shape=[batch, pair_of_relation, 2*hidden=1536]
# get corresponding size embeddings
size_pair_embeddings = util.batch_index(size_embeddings, relations)
# size_pair_embeddings=[batch, pair_of_relation, 2, size_embedding=25]
size_pair_embeddings = size_pair_embeddings.view(batch_size, size_pair_embeddings.shape[1], -1)
# size_pair_embeddings=[batch, pair_of_relation, 2 * size_embedding=25]
# relation context (context between entity candidate pair)
# mask non entity candidate tokens
m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
rel_ctx = m + h
# rel_ctx=[batch, min(pair_of_relation, self._max_pairs), seq, hidden=768]
# max pooling
rel_ctx = rel_ctx.max(dim=2)[0]
# rel_ctx=[batch, min(pair_of_relation, self._max_pairs), hidden=768]
# set the context vector of neighboring or adjacent entity candidates to zero
rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0
# create relation candidate representations including context, max pooled entity candidate pairs
# and corresponding size embeddings
rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings], dim=2)
# rel_repr=[batch, pair_of_relation, 3*hidden+size_embedding]
rel_repr = self.dropout(rel_repr)
# classify relation candidates
chunk_rel_logits = self.rel_classifier(rel_repr)
# chunk_rel_logits=[batch, pair_of_relation, relation_types]
return chunk_rel_logits
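util.batch_index gathers rows per batch element; conceptually it behaves like the simplified stand-in below (a sketch, not the repository's exact implementation):
import torch

def batch_index_sketch(tensor, index):
    # tensor: [batch, n, hidden], index: [batch, pairs, 2] (long)
    # result: [batch, pairs, 2, hidden], i.e. tensor[b][index[b]] for every b
    return torch.stack([tensor[b][index[b]] for b in range(tensor.shape[0])])

entity_spans = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # 3 entity vectors, hidden=2
relations = torch.tensor([[[0, 2], [2, 1]]])                     # 2 candidate pairs
print(batch_index_sketch(entity_spans, relations).shape)         # torch.Size([1, 2, 2, 2])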
Span filtering
def _filter_spans(self, entity_clf, entity_spans, entity_sample_masks, ctx_size):
batch_size = entity_clf.shape[0]
entity_logits_max = entity_clf.argmax(dim=-1) * entity_sample_masks.long() # get entity type (including none)
batch_relations = []
batch_rel_masks = []
batch_rel_sample_masks = []
for i in range(batch_size):
rels = []
rel_masks = []
sample_masks = []
# get spans classified as entities
non_zero_indices = (entity_logits_max[i] != 0).nonzero().view(-1)
non_zero_spans = entity_spans[i][non_zero_indices].tolist()
non_zero_indices = non_zero_indices.tolist()
# create relations and masks
for i1, s1 in zip(non_zero_indices, non_zero_spans):
for i2, s2 in zip(non_zero_indices, non_zero_spans):
if i1 != i2:
rels.append((i1, i2))
rel_masks.append(sampling.create_rel_mask(s1, s2, ctx_size))
sample_masks.append(1)
if not rels:
# case: fewer than two spans classified as entities
batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long))
batch_rel_masks.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
batch_rel_sample_masks.append(torch.tensor([0], dtype=torch.bool))
else:
# case: at least two spans classified as entities
batch_relations.append(torch.tensor(rels, dtype=torch.long))
batch_rel_masks.append(torch.stack(rel_masks))
batch_rel_sample_masks.append(torch.tensor(sample_masks, dtype=torch.bool))
# stack
device = self.rel_classifier.weight.device
batch_relations = util.padded_stack(batch_relations).to(device)
batch_rel_masks = util.padded_stack(batch_rel_masks).to(device)
batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(device)
return batch_relations, batch_rel_masks, batch_rel_sample_masks