Starting from the test code, we can see that the instantiated BART object is a BARTModel, and calling from_pretrained returns a hub interface (a BARTHubInterface).
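As a concrete entry point, here is a minimal sketch of that test-code path; the checkpoint directory and file name are placeholder assumptions, not values from the original text. from_pretrained builds the hub interface, and sample() drives the generate call traced below.

import torch
from fairseq.models.bart import BARTModel

# Minimal sketch; "checkpoints/" and "model.pt" are hypothetical paths.
bart = BARTModel.from_pretrained("checkpoints/", checkpoint_file="model.pt")
bart.eval()  # disable dropout for inference

# sample() tokenizes, batches, and calls the hub generate() shown below.
hypotheses = bart.sample(["Some source document to summarize ."], beam=6)
print(hypotheses[0])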
hub_utils (GeneratorHubInterface.generate):
generator = self.task.build_generator(
    self.models,
    gen_args,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
inference_step_args = inference_step_args or {}
results = []
for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
    batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
    translations = self.task.inference_step(
        generator, self.models, batch, **inference_step_args
    )
    for id, hypos in zip(batch["id"].tolist(), translations):
        results.append((id, hypos))

# sort output to match input order
outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
The translation task uses build_generator from FairseqTask.
At test time the plain BeamSearch strategy is used:
search_strategy = search.BeamSearch(self.target_dictionary)
The call returns an ordinary SequenceGenerator:
if seq_gen_cls is None:
    if getattr(args, "print_alignment", False):
        seq_gen_cls = SequenceGeneratorWithAlignment
        extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
    else:
        print("common generator")
        seq_gen_cls = SequenceGenerator
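For reference, here is a rough sketch of what this amounts to in the common case, assuming `task` and `models` already exist (e.g. `task = bart.task` and `models = bart.models` from the hypothetical hub interface above): a plain SequenceGenerator driven by BeamSearch.

from fairseq import search
from fairseq.sequence_generator import SequenceGenerator

# Sketch only: equivalent in spirit to task.build_generator(models, gen_args)
# when no alignment printing or constrained decoding is requested.
search_strategy = search.BeamSearch(task.target_dictionary)
generator = SequenceGenerator(
    models,
    task.target_dictionary,
    beam_size=6,
    search_strategy=search_strategy,
)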
self.task.inference_step likewise comes from fairseq_task:
def inference_step(
    self, generator, models, sample, prefix_tokens=None, constraints=None
):
    with torch.no_grad():
        return generator.generate(
            models, sample, prefix_tokens=prefix_tokens, constraints=constraints
        )
Back in the generator, generate() is a thin wrapper that delegates to _generate().
net_input = sample["net_input"]'''{'src_tokens': tensor([[ 0, 1594, 2654, ..., 655, 479, 2],[ 1, 1, 1, ..., 212, 479, 2],[ 1, 1, 1, ..., 479, 12801, 2],[ 1, 1, 1, ..., 1676, 479, 2]], device='cuda:0'),'src_lengths': tensor([410, 394, 345, 340], device='cuda:0')}'''
The generator first runs the encoder forward pass:
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
    encoder_outs = self.model.forward_encoder(net_input)
Here self.model is an EnsembleModel, defined in the same file (sequence_generator.py).
The EnsembleModel's forward_encoder, shown below, is wrapped with @torch.jit.export. This decorator is a TorchScript feature: it marks a method other than forward to be compiled and kept when the module is scripted with torch.jit.script.
@torch.jit.export
def forward_encoder(self, net_input: Dict[str, Tensor]):
    if not self.has_encoder():
        return None
    # print(self.models)
    return [model.encoder.forward_torchscript(net_input) for model in self.models]
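To see what @torch.jit.export does, here is a standalone toy (not fairseq code): methods other than forward are only compiled by torch.jit.script if they are reachable from forward or explicitly exported, and exporting keeps them callable on the scripted module.

import torch
from torch import nn

class Toy(nn.Module):
    @torch.jit.export
    def double(self, x: torch.Tensor) -> torch.Tensor:
        # compiled and kept on the scripted module because of @torch.jit.export
        return x * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

scripted = torch.jit.script(Toy())
print(scripted.double(torch.ones(2)))  # tensor([2., 2.])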
The forward_torchscript it calls is defined in fairseq_encoder:
def forward_torchscript(self, net_input: Dict[str, Tensor]):
    """A TorchScript-compatible version of forward.

    Encoders which use additional arguments may want to override
    this method for TorchScript compatibility.
    """
    if torch.jit.is_scripting():
        print("Hit jit forwarding")
        return self.forward(
            src_tokens=net_input["src_tokens"],
            src_lengths=net_input["src_lengths"],
        )
    else:
        print("Hit nonjit forwarding")
        return self.forward_non_torchscript(net_input)
Since we are not scripting here, the non-TorchScript branch is taken:
@torch.jit.unused
def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
    encoder_input = {
        k: v for k, v in net_input.items() if k != "prev_output_tokens"
    }
    return self.forward(**encoder_input)
The strangest part is that this self.forward does not end up in FairseqEncoder's own forward (which is marked as not implemented there); dynamic dispatch resolves it to the subclass TransformerEncoderBase's forward, which in turn hands off to forward_scriptable, and from there the usual encoder forward pass proceeds.
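A standalone toy (not fairseq code) of why that works: forward_non_torchscript lives on the base class, but self is a TransformerEncoderBase instance, so self.forward resolves to the subclass override via ordinary dynamic dispatch.

class BaseEncoder:
    def forward(self, **kwargs):
        raise NotImplementedError  # like FairseqEncoder.forward

    def forward_non_torchscript(self, net_input):
        # self.forward is looked up on type(self) at call time
        return self.forward(**net_input)

class ToyTransformerEncoder(BaseEncoder):
    def forward(self, src_tokens=None, src_lengths=None):
        return f"encoded {len(src_tokens)} tokens"

enc = ToyTransformerEncoder()
print(enc.forward_non_torchscript({"src_tokens": [0, 11, 2], "src_lengths": [3]}))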
Next comes the decoding process.
The code block below sorts the beam hypotheses within each sample by score.
Running XSum as an example, with a batch size of 7 and a beam size of 6,
finalized is a list of length 7, and each element is a sub-list of length 6.
Each entry in a sub-list is a dict containing:
tokens
score
attention
alignment
positional_scores
for sent in range(len(finalized)):
    scores = torch.tensor(
        [float(elem["score"].item()) for elem in finalized[sent]]
    )
    _, sorted_scores_indices = torch.sort(scores, descending=True)
    finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
    finalized[sent] = torch.jit.annotate(
        List[Dict[str, Tensor]], finalized[sent]
    )
return finalized
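A standalone toy (fabricated scores, not real model output) showing the resulting structure and how the top hypothesis per sentence is read off after this sort:

import torch

# Toy stand-in for `finalized`: 2 sentences, beam size 2.
finalized = [
    [
        {"tokens": torch.tensor([0, 5, 2]), "score": torch.tensor(-0.9)},
        {"tokens": torch.tensor([0, 6, 2]), "score": torch.tensor(-0.3)},
    ],
    [
        {"tokens": torch.tensor([0, 7, 2]), "score": torch.tensor(-0.4)},
        {"tokens": torch.tensor([0, 8, 2]), "score": torch.tensor(-1.1)},
    ],
]

for sent in range(len(finalized)):
    scores = torch.tensor([float(e["score"].item()) for e in finalized[sent]])
    _, idx = torch.sort(scores, descending=True)
    finalized[sent] = [finalized[sent][i] for i in idx]

# After sorting, index 0 of each sub-list is the best beam.
print([beams[0]["score"].item() for beams in finalized])  # roughly [-0.3, -0.4]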
