Starting from the test code: the instantiated BART object is a BARTModel, and calling from_pretrained returns a hub interface through which generation runs.
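As a refresher, the entry point looks roughly like this; a minimal sketch, assuming a local XSum checkpoint directory (the path, beam size, and length penalty here are illustrative, not taken from the original test code):

    from fairseq.models.bart import BARTModel

    # hypothetical checkpoint location; any BART fine-tune loads the same way
    bart = BARTModel.from_pretrained(
        "checkpoints/bart.large.xsum",
        checkpoint_file="model.pt",
    )
    bart.cuda()
    bart.eval()
    # sample() tokenizes, batches, and drives the generate() pipeline traced below
    hypotheses = bart.sample(["Some long source article ..."], beam=6, lenpen=1.0)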

hub_utils.py (GeneratorHubInterface.generate):

    generator = self.task.build_generator(
        self.models,
        gen_args,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    )
    inference_step_args = inference_step_args or {}
    results = []
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
        batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
        translations = self.task.inference_step(
            generator, self.models, batch, **inference_step_args
        )
        for id, hypos in zip(batch["id"].tolist(), translations):
            results.append((id, hypos))
    # sort output to match input order
    outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
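The id-based re-sort at the end matters because _build_batches may reorder the inputs (typically by length, to reduce padding waste). A toy illustration of the idiom:

    # hypotheses come back in batch order, each tagged with its original position
    results = [(2, "hypo C"), (0, "hypo A"), (1, "hypo B")]
    outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
    assert outputs == ["hypo A", "hypo B", "hypo C"]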

TranslationTask uses the build_generator from FairseqTask.

At test time the plain beam search strategy is selected:

    search_strategy = search.BeamSearch(self.target_dictionary)
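For context, build_generator picks the search strategy from the generation args; the sketch below paraphrases that selection logic from memory (condensed, not a verbatim copy of fairseq's code):

    from fairseq import search

    sampling = getattr(args, "sampling", False)
    diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary,
            getattr(args, "sampling_topk", -1),
            getattr(args, "sampling_topp", -1.0),
        )
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary,
            diverse_beam_groups,
            getattr(args, "diverse_beam_strength", 0.5),
        )
    else:
        # the default branch, which this test run takes
        search_strategy = search.BeamSearch(self.target_dictionary)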

The call ends up returning an ordinary SequenceGenerator:

    if seq_gen_cls is None:
        if getattr(args, "print_alignment", False):
            seq_gen_cls = SequenceGeneratorWithAlignment
            extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
        else:
            print("common generator")
            seq_gen_cls = SequenceGenerator
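Nothing stops you from constructing one by hand; a minimal sketch, assuming models is a list of trained fairseq models that share the target dictionary tgt_dict (the beam size and length bound are illustrative):

    from fairseq import search
    from fairseq.sequence_generator import SequenceGenerator

    generator = SequenceGenerator(
        models,          # list of models; wrapped into an EnsembleModel internally
        tgt_dict,
        beam_size=6,
        max_len_b=200,   # generated length is capped at max_len_a * src_len + max_len_b
        search_strategy=search.BeamSearch(tgt_dict),
    )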

self.task.inference_step likewise comes from fairseq_task:

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, constraints=constraints
            )

Back in the generator: the generate being called is a thin shell that delegates to _generate. Inside _generate, the net_input is pulled out of the sample:

    net_input = sample["net_input"]
    '''
    {'src_tokens': tensor([[   0, 1594, 2654,  ...,   655,   479,     2],
            [   1,    1,    1,  ...,   212,   479,     2],
            [   1,    1,    1,  ...,   479, 12801,     2],
            [   1,    1,    1,  ...,  1676,   479,     2]], device='cuda:0'),
     'src_lengths': tensor([410, 394, 345, 340], device='cuda:0')
    }
    '''
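The dump shows fairseq's left-padding of the source: shorter sequences are front-filled with the pad index 1, and every row ends with eos (2). A small reproduction with fairseq's collate utility, using toy token ids:

    import torch
    from fairseq.data import data_utils

    tokens = [torch.tensor([0, 11, 12, 2]), torch.tensor([0, 21, 2])]
    src_tokens = data_utils.collate_tokens(tokens, pad_idx=1, left_pad=True)
    # tensor([[ 0, 11, 12,  2],
    #         [ 1,  0, 21,  2]])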

The generator first runs the encoder forward:

    with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
        encoder_outs = self.model.forward_encoder(net_input)
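record_function only attaches a label to a code region so it shows up as a named block in profiler traces; it is effectively a no-op unless a profiler is active. A standalone example:

    import torch

    with torch.profiler.profile() as prof:
        with torch.autograd.profiler.record_function("my_encoder_region"):
            torch.randn(256, 256) @ torch.randn(256, 256)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))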

Here self.model is an EnsembleModel, defined in the same file.
The ensemble's forward_encoder, shown below, is decorated with @torch.jit.export. That decorator is a TorchScript feature, not a speed knob: torch.jit.script only compiles forward and whatever forward reaches, so a method invoked from the outside, like this one, must be explicitly exported to survive scripting.

    @torch.jit.export
    def forward_encoder(self, net_input: Dict[str, Tensor]):
        if not self.has_encoder():
            return None
        # print(self.models)
        return [model.encoder.forward_torchscript(net_input) for model in self.models]
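A self-contained demonstration of the decorator's effect:

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

        @torch.jit.export
        def helper(self, x: torch.Tensor) -> torch.Tensor:
            # never called from forward, so without @torch.jit.export
            # this method would not survive torch.jit.script()
            return x + 1

    scripted = torch.jit.script(Toy())
    print(scripted.helper(torch.ones(2)))  # tensor([2., 2.])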

This forward_torchscript is found in fairseq_encoder:

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            print("Hit jit forwarding")
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            print("Hit nonjit forwarding")
            return self.forward_non_torchscript(net_input)
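The branching relies on torch.jit.is_scripting(), which is False in eager mode and True when the code runs under TorchScript; a minimal check:

    import torch

    def mode() -> str:
        if torch.jit.is_scripting():
            return "scripted"
        return "eager"

    print(mode())                    # eager
    print(torch.jit.script(mode)())  # scripted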

In this eager-mode run, the non-TorchScript branch is taken:

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens"
        }
        return self.forward(**encoder_input)

The strangest-looking part is this self.forward call: it does not hit FairseqEncoder's own forward (which is deliberately left unimplemented) but the forward of the subclass, TransformerEncoderBase. That forward in turn hands off to forward_scriptable, and from there the usual encoder forward pass proceeds.
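There is no fairseq magic here, just ordinary Python dynamic dispatch: self is a TransformerEncoderBase instance, so self.forward resolves to the subclass override. A stripped-down analogue:

    class FairseqEncoderLike:
        def forward(self, x):
            raise NotImplementedError

        def forward_non_torchscript(self, x):
            return self.forward(x)  # resolves against the dynamic type of self

    class TransformerEncoderLike(FairseqEncoderLike):
        def forward(self, x):
            return x + 1  # stand-in for the real encoder forward

    print(TransformerEncoderLike().forward_non_torchscript(41))  # 42, base forward never runs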

Then comes the decoding stage.
The code block below sorts the beams within each sample.
Taking an XSum run as an example, with a batch size of 7 and a beam size of 6, finalized is a list of length 7; each element is an inner list of length 6, and each entry of an inner list is a dict with the keys:

    tokens
    score
    attention
    alignment
    positional_scores

The per-sample sort by score looks like this:
    for sent in range(len(finalized)):
        scores = torch.tensor(
            [float(elem["score"].item()) for elem in finalized[sent]]
        )
        _, sorted_scores_indices = torch.sort(scores, descending=True)
        finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
        finalized[sent] = torch.jit.annotate(
            List[Dict[str, Tensor]], finalized[sent]
        )
    return finalized
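After the sort, finalized[sent][0] is the best-scoring hypothesis for each sentence. A sketch of consuming the result, reusing the hub interface from the top of this note (bart is assumed to be the loaded hub object):

    best = finalized[0][0]
    best_tokens = best["tokens"]       # LongTensor of target token ids
    best_score = best["score"].item()  # log-probability, length-normalized by default
    print(bart.decode(best_tokens))    # map ids back to a string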