Starting from the test code: the instantiated BART object is a BARTModel, and calling from_pretrained returns a hub interface (a BARTHubInterface, which extends GeneratorHubInterface).
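As a concrete sketch of that entry point (the checkpoint and data paths here are placeholders, not from the original run):

import torch
from fairseq.models.bart import BARTModel

# Hypothetical paths; substitute your own checkpoint and binarized data.
bart = BARTModel.from_pretrained(
    "checkpoints/",
    checkpoint_file="model.pt",
    data_name_or_path="xsum-bin",
)
bart.eval()
if torch.cuda.is_available():
    bart.cuda()
# sample() on the hub interface is what leads into the code below.
hypos = bart.sample(["Some source document ..."], beam=6, lenpen=1.0)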
hub_utils.py: inside GeneratorHubInterface.generate, a generator is built and then run batch by batch:
generator = self.task.build_generator(
self.models,
gen_args,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
inference_step_args = inference_step_args or {}
results = []
for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
translations = self.task.inference_step(
generator, self.models, batch, **inference_step_args
)
for id, hypos in zip(batch["id"].tolist(), translations):
results.append((id, hypos))
# sort output to match input order
outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
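The id bookkeeping and final sort exist because _build_batches may reorder sentences (e.g., bucketing by length) for efficient batching, so each hypothesis is tagged with its original index and the input order is restored at the end. A toy version of the same idiom:

# Toy illustration of the restore-order idiom used above.
results = [(2, "c"), (0, "a"), (1, "b")]  # (original_id, hypothesis)
outputs = [h for _, h in sorted(results, key=lambda x: x[0])]
print(outputs)  # ['a', 'b', 'c']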
TranslationTask uses build_generator from FairseqTask.
At test time it falls through to plain beam search:
search_strategy = search.BeamSearch(self.target_dictionary)
and the call returns a plain SequenceGenerator:
if seq_gen_cls is None:
if getattr(args, "print_alignment", False):
seq_gen_cls = SequenceGeneratorWithAlignment
extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
else:
print("common generator")
seq_gen_cls = SequenceGenerator
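For reference, a minimal sketch of how such a generator is built from the caller's side (the argument values are assumptions, not taken from the original run; build_generator reads fields like beam, max_len_b, and lenpen off the args object via getattr):

from argparse import Namespace

# Hypothetical generation config; `task` and `models` are assumed to
# come from the surrounding setup.
gen_args = Namespace(beam=6, max_len_a=0, max_len_b=60, lenpen=1.0)
generator = task.build_generator(models, gen_args)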
self.task.inference_step likewise comes from fairseq_task:
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, constraints=constraints
)
Back in the generator: the public generate is a thin wrapper that delegates to _generate.
net_input = sample["net_input"]
'''
{'src_tokens': tensor([[ 0, 1594, 2654, ..., 655, 479, 2],
[ 1, 1, 1, ..., 212, 479, 2],
[ 1, 1, 1, ..., 479, 12801, 2],
[ 1, 1, 1, ..., 1676, 479, 2]], device='cuda:0'),
'src_lengths': tensor([410, 394, 345, 340], device='cuda:0')
}
'''
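Note the layout: sources are left-padded, and in the default fairseq dictionary 0 is <s>, 1 is <pad>, and 2 is </s>. A minimal sketch of building such a batch by hand with fairseq's collate helper (the token ids are made up):

import torch
from fairseq.data import data_utils

# Made-up token ids; 0 = <s>, 1 = <pad>, 2 = </s> by default.
tokens = [
    torch.tensor([0, 1594, 2654, 655, 479, 2]),
    torch.tensor([0, 212, 479, 2]),
]
src_tokens = data_utils.collate_tokens(tokens, pad_idx=1, left_pad=True)
src_lengths = torch.tensor([t.numel() for t in tokens])
net_input = {"src_tokens": src_tokens, "src_lengths": src_lengths}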
The generator first runs the encoder forward pass:
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
encoder_outs = self.model.forward_encoder(net_input)
Here self.model is an EnsembleModel, defined in the same file; SequenceGenerator wraps the incoming list of models in one inside its constructor.
EnsembleModel's forward_encoder, shown below, is a thin wrapper decorated with @torch.jit.export. That decorator is a TorchScript feature: torch.jit.script only compiles forward (and whatever forward calls) by default, and @torch.jit.export marks additional methods to be compiled and kept callable on the scripted module.
@torch.jit.export
def forward_encoder(self, net_input: Dict[str, Tensor]):
if not self.has_encoder():
return None
#print(self.models)
return [model.encoder.forward_torchscript(net_input) for model in self.models]
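A standalone toy illustration of @torch.jit.export (not fairseq code):

import torch

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.jit.export  # compiled even though forward never calls it
    def helper(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

scripted = torch.jit.script(Toy())
print(scripted.helper(torch.ones(2)))  # tensor([2., 2.])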
This function turns up in fairseq_encoder.py:
def forward_torchscript(self, net_input: Dict[str, Tensor]):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
if torch.jit.is_scripting():
print("Hit jit forwarding")
return self.forward(
src_tokens=net_input["src_tokens"],
src_lengths=net_input["src_lengths"],
)
else:
print("Hit nonjit forwarding")
return self.forward_non_torchscript(net_input)
Here the non-TorchScript branch is taken; it strips prev_output_tokens (a decoder-side input) out of net_input and forwards the rest:
@torch.jit.unused
def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
encoder_input = {
k: v for k, v in net_input.items() if k != "prev_output_tokens"
}
return self.forward(**encoder_input)
The strangest-looking part is the forward called here: it is not FairseqEncoder's own forward (which is marked not implemented) but, through ordinary Python dynamic dispatch on self, the subclass TransformerEncoderBase's forward. That forward immediately hands off to forward_scriptable, and from there the normal encoder forward pass proceeds.
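Nothing TorchScript-specific is involved in that dispatch; it is plain method resolution, as this toy sketch (not fairseq code) shows:

# The base-class helper calls self.forward, which resolves to the
# subclass override at runtime.
class BaseEncoder:
    def forward(self, x):
        raise NotImplementedError

    def forward_non_torchscript(self, x):
        return self.forward(x)  # dynamic dispatch picks the override

class TransformerEncoder(BaseEncoder):
    def forward(self, x):
        return x + 1

print(TransformerEncoder().forward_non_torchscript(41))  # 42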
Next, the decoding output. The code block below sorts the beam hypotheses within each sample by score.
Taking an XSum run as an example, with a batch size of 7 and a beam size of 6:
finalized is a list of length 7, and each of its elements is an inner list of length 6.
Each entry of an inner list is a dict containing:
tokens
score
attention
alignment
positional_scores
for sent in range(len(finalized)):
scores = torch.tensor(
[float(elem["score"].item()) for elem in finalized[sent]]
)
_, sorted_scores_indices = torch.sort(scores, descending=True)
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
finalized[sent] = torch.jit.annotate(
List[Dict[str, Tensor]], finalized[sent]
)
return finalized
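After the sort, index 0 of each inner list is the best beam. A hedged usage sketch (tgt_dict is an assumed target Dictionary, not named in the original):

# Hypothetical consumer of the sorted output: finalized[sent][0] is
# now the highest-scoring hypothesis for sentence `sent`.
best = finalized[0][0]
print(best["score"].item())             # total model score of the best beam
print(tgt_dict.string(best["tokens"]))  # detokenize; `tgt_dict` is assumed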