定义

预训练模型(Pretrained model):

一般情况下预训练模型都是大型模型,具备复杂的网络结构,众多的参数量,以及在足够大的数据集下进行训练而产生的模型. 在NLP领域,预训练模型往往是语言模型,因为语言模型的训练是无监督的,可以获得大规模语料,同时语言模型又是许多典型NLP任务的基础,如机器翻译,文本生成,阅读理解等,常见的预训练模型有BERT, GPT, roBERTa, transformer-XL等.

微调(Fine-tuning):

根据给定的预训练模型,改变它的部分参数或者为其新增部分输出结构后,通过在小部分数据集上训练,来使整个模型更好的适应特定任务.

微调脚本(Fine-tuning script):

实现微调过程的代码文件。这些脚本文件中,应包括对预训练模型的调用,对微调参数的选定以及对微调结构的更改等,同时,因为微调是一个训练过程,它同样需要一些超参数的设定,以及损失函数和优化器的选取等, 因此微调脚本往往也包含了整个迁移学习的过程.

一般情况下,微调脚本应该由不同的任务类型开发者自己编写,但是由于目前研究的NLP任务类型(分类,提取,生成)以及对应的微调输出结构都是有限的,有些微调方式已经在很多数据集上被验证是有效的,因此微调脚本也可以使用已经完成的规范脚本.

迁移方式

  • 直接使用预训练模型,进行相同任务的处理,不需要调整参数或模型结构,这些模型开箱即用。但是这种情况一般只适用于普适任务, 如:fasttest工具包中预训练的词向量模型。另外,很多预训练模型开发者为了达到开箱即用的效果,将模型结构分各个部分保存为不同的预训练模型,提供对应的加载方法来完成特定目标.
  • 更加主流的迁移学习方式是发挥预训练模型特征抽象的能力,然后再通过微调的方式,通过训练更新小部分参数以此来适应不同的任务。这种迁移方式需要提供小部分的标注数据来进行监督学习.

GLUE数据集合

GLUE由纽约大学, 华盛顿大学, Google联合推出, 涵盖不同NLP任务类型, 截止至2020年1月其中包括11个子任务数据集, 成为衡量NLP研究发展的衡量标准.

GLUE数据集合包含以下数据集

  • CoLA 数据集
  • SST-2 数据集
  • MRPC 数据集
  • STS-B 数据集
  • QQP 数据集
  • MNLI 数据集
  • SNLI 数据集
  • QNLI 数据集
  • RTE 数据集
  • WNLI 数据集
  • diagnostics数据集(官方未完善)

下载方式

  1. ''' Script for downloading all GLUE data.'''
  2. import os
  3. import sys
  4. import shutil
  5. import argparse
  6. import tempfile
  7. import urllib.request
  8. import zipfile
  9. TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
  10. TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
  11. "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
  12. "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
  13. "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
  14. "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
  15. "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
  16. "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
  17. "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
  18. "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
  19. "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
  20. "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
  21. MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
  22. MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
  23. def download_and_extract(task, data_dir):
  24. print("Downloading and extracting %s..." % task)
  25. data_file = "%s.zip" % task
  26. urllib.request.urlretrieve(TASK2PATH[task], data_file)
  27. with zipfile.ZipFile(data_file) as zip_ref:
  28. zip_ref.extractall(data_dir)
  29. os.remove(data_file)
  30. print("\tCompleted!")
  31. def format_mrpc(data_dir, path_to_data):
  32. print("Processing MRPC...")
  33. mrpc_dir = os.path.join(data_dir, "MRPC")
  34. if not os.path.isdir(mrpc_dir):
  35. os.mkdir(mrpc_dir)
  36. if path_to_data:
  37. mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
  38. mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
  39. else:
  40. print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
  41. mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
  42. mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
  43. urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
  44. urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
  45. assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
  46. assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
  47. urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
  48. dev_ids = []
  49. with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
  50. for row in ids_fh:
  51. dev_ids.append(row.strip().split('\t'))
  52. with open(mrpc_train_file, encoding="utf8") as data_fh, \
  53. open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
  54. open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
  55. header = data_fh.readline()
  56. train_fh.write(header)
  57. dev_fh.write(header)
  58. for row in data_fh:
  59. label, id1, id2, s1, s2 = row.strip().split('\t')
  60. if [id1, id2] in dev_ids:
  61. dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
  62. else:
  63. train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
  64. with open(mrpc_test_file, encoding="utf8") as data_fh, \
  65. open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
  66. header = data_fh.readline()
  67. test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
  68. for idx, row in enumerate(data_fh):
  69. label, id1, id2, s1, s2 = row.strip().split('\t')
  70. test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
  71. print("\tCompleted!")
  72. def download_diagnostic(data_dir):
  73. print("Downloading and extracting diagnostic...")
  74. if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
  75. os.mkdir(os.path.join(data_dir, "diagnostic"))
  76. data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
  77. urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
  78. print("\tCompleted!")
  79. return
  80. def get_tasks(task_names):
  81. task_names = task_names.split(',')
  82. if "all" in task_names:
  83. tasks = TASKS
  84. else:
  85. tasks = []
  86. for task_name in task_names:
  87. assert task_name in TASKS, "Task %s not found!" % task_name
  88. tasks.append(task_name)
  89. return tasks
  90. def main(arguments):
  91. parser = argparse.ArgumentParser()
  92. parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
  93. parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
  94. type=str, default='all')
  95. parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
  96. type=str, default='')
  97. args = parser.parse_args(arguments)
  98. if not os.path.isdir(args.data_dir):
  99. os.mkdir(args.data_dir)
  100. tasks = get_tasks(args.tasks)
  101. for task in tasks:
  102. if task == 'MRPC':
  103. format_mrpc(args.data_dir, args.path_to_mrpc)
  104. elif task == 'diagnostic':
  105. download_diagnostic(args.data_dir)
  106. else:
  107. download_and_extract(task, args.data_dir)
  108. if __name__ == '__main__':
  109. sys.exit(main(sys.argv[1:]))

NLP中的常用预训练模型

当下NLP中流行的预训练模型

  • BERT
  • GPT
  • GPT-2
  • Transformer-XL
  • XLNet
  • XLM
  • RoBERTa
  • DistilBERT
  • ALBERT
  • T5
  • XLM-RoBERTa

  • BERT及其变体:
    • bert-base-uncased: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在小写的英文文本上进行训练而得到.
    • bert-large-uncased: 编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共340M参数量, 在小写的英文文本上进行训练而得到.
    • bert-base-cased: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在不区分大小写的英文文本上进行训练而得到.
    • bert-large-cased: 编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共340M参数量, 在不区分大小写的英文文本上进行训练而得到.
    • bert-base-multilingual-uncased: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在小写的102种语言文本上进行训练而得到.
    • bert-large-multilingual-uncased: 编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共340M参数量, 在小写的102种语言文本上进行训练而得到.
    • bert-base-chinese: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在简体和繁体中文文本上进行训练而得到.

  • GPT:
    • openai-gpt: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 由OpenAI在英文语料上进行训练而得到.

  • GPT-2及其变体:
    • gpt2: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共117M参数量, 在OpenAI GPT-2英文语料上进行训练而得到.
    • gpt2-xl: 编码器具有48个隐层, 输出1600维张量, 25个自注意力头, 共1558M参数量, 在大型的OpenAI GPT-2英文语料上进行训练而得到.

  • Transformer-XL:
    • transfo-xl-wt103: 编码器具有18个隐层, 输出1024维张量, 16个自注意力头, 共257M参数量, 在wikitext-103英文语料进行训练而得到.

  • XLNet及其变体:
    • xlnet-base-cased: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在英文语料上进行训练而得到.
    • xlnet-large-cased: 编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共240参数量, 在英文语料上进行训练而得到.

  • XLM:
    • xlm-mlm-en-2048: 编码器具有12个隐层, 输出2048维张量, 16个自注意力头, 在英文文本上进行训练而得到.

  • RoBERTa及其变体:
    • roberta-base: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共125M参数量, 在英文文本上进行训练而得到.
    • roberta-large: 编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共355M参数量, 在英文文本上进行训练而得到.

  • DistilBERT及其变体:
    • distilbert-base-uncased: 基于bert-base-uncased的蒸馏(压缩)模型, 编码器具有6个隐层, 输出768维张量, 12个自注意力头, 共66M参数量.
    • distilbert-base-multilingual-cased: 基于bert-base-multilingual-uncased的蒸馏(压缩)模型, 编码器具有6个隐层, 输出768维张量, 12个自注意力头, 共66M参数量.

  • ALBERT:
    • albert-base-v1: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共125M参数量, 在英文文本上进行训练而得到.
    • albert-base-v2: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共125M参数量, 在英文文本上进行训练而得到, 相比v1使用了更多的数据量, 花费更长的训练时间.

  • T5及其变体:
    • t5-small: 编码器具有6个隐层, 输出512维张量, 8个自注意力头, 共60M参数量, 在C4语料上进行训练而得到.
    • t5-base: 编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共220M参数量, 在C4语料上进行训练而得到.
    • t5-large: 编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共770M参数量, 在C4语料上进行训练而得到.

  • XLM-RoBERTa及其变体:
    • xlm-roberta-base: 编码器具有12个隐层, 输出768维张量, 8个自注意力头, 共125M参数量, 在2.5TB的100种语言文本上进行训练而得到.
    • xlm-roberta-large: 编码器具有24个隐层, 输出1027维张量, 16个自注意力头, 共355M参数量, 在2.5TB的100种语言文本上进行训练而得到.

  • 预训练模型说明:
    • 所有上述预训练模型及其变体都是以transformer为基础,只是在模型结构如神经元连接方式,编码器隐层数,多头注意力的头数等发生改变,这些改变方式的大部分依据都是由在标准数据集上的表现而定,因此,对于我们使用者而言,不需要从理论上深度探究这些预训练模型的结构设计的优劣,只需要在自己处理的目标数据上,尽量遍历所有可用的模型对比得到最优效果即可.

加载和使用预训练模型

步骤

  • 第一步: 确定需要加载的预训练模型并安装依赖包.
  • 第二步: 加载预训练模型的映射器tokenizer.
  • 第三步: 加载带/不带头的预训练模型.
  • 第四步: 使用模型获得输出结果.

确定需要加载的预训练模型并安装依赖包

  • 需要加载的模型是BERT的中文模型: bert-base-chinese
  • 在使用工具加载模型前需要安装必备的依赖包:
  1. pip install tqdm boto3 requests regex sentencepiece sacremoses

加载预训练模型的映射器tokenizer

  1. import torch
  2. # 预训练模型来源
  3. source = 'huggingface/pytorch-transformers'
  4. # 选定加载模型的哪一部分, 这里是模型的映射器
  5. part = 'tokenizer'
  6. # 加载的预训练模型的名字
  7. model_name = 'bert-base-chinese'
  8. tokenizer = torch.hub.load(source, part, model_name)

加载带/不带头的预训练模型

  • 加载预训练模型时我们可以选择带头或者不带头的模型
  • 这里的’头’是指模型的任务输出层, 选择加载不带头的模型, 相当于使用模型对输入文本进行特征表示.
  • 选择加载带头的模型时, 有三种类型的’头’可供选择, modelWithLMHead(语言模型头), modelForSequenceClassification(分类模型头), modelForQuestionAnswering(问答模型头)
  • 不同类型的’头’, 可以使预训练模型输出指定的张量维度. 如使用’分类模型头’, 则输出尺寸为(1,2)的张量, 用于进行分类任务判定结果.
  1. # 加载不带头的预训练模型
  2. part = 'model'
  3. model = torch.hub.load(source, part, model_name)
  4. # 加载带有语言模型头的预训练模型
  5. part = 'modelWithLMHead'
  6. lm_model = torch.hub.load(source, part, model_name)
  7. # 加载带有类模型头的预训练模型
  8. part = 'modelForSequenceClassification'
  9. classification_model = torch.hub.load(source, part, model_name)
  10. # 加载带有问答模型头的预训练模型
  11. part = 'modelForQuestionAnswering'
  12. qa_model = torch.hub.load(source, part, model_name)

使用模型获得输出结果

  • 使用不带头的模型进行输出:
  1. # 输入的中文文本
  2. input_text = "人生该如何起头"
  3. # 使用tokenizer进行数值映射
  4. indexed_tokens = tokenizer.encode(input_text)
  5. # 打印映射后的结构
  6. print("indexed_tokens:", indexed_tokens)
  7. # 将映射结构转化为张量输送给不带头的预训练模型
  8. tokens_tensor = torch.tensor([indexed_tokens])
  9. # 使用不带头的预训练模型获得结果
  10. with torch.no_grad():
  11. encoded_layers, _ = model(tokens_tensor)
  12. print("不带头的模型输出结果:", encoded_layers)
  13. print("不带头的模型输出结果的尺寸:", encoded_layers.shape)
  • 输出效果:
  1. # tokenizer映射后的结果, 101和102是起止符,
  2. # 中间的每个数字对应"人生该如何起头"的每个字.
  3. indexed_tokens: [101, 782, 4495, 6421, 1963, 862, 6629, 1928, 102]
  4. 不带头的模型输出结果: tensor([[[ 0.5421, 0.4526, -0.0179, ..., 1.0447, -0.1140, 0.0068],
  5. [-0.1343, 0.2785, 0.1602, ..., -0.0345, -0.1646, -0.2186],
  6. [ 0.9960, -0.5121, -0.6229, ..., 1.4173, 0.5533, -0.2681],
  7. ...,
  8. [ 0.0115, 0.2150, -0.0163, ..., 0.6445, 0.2452, -0.3749],
  9. [ 0.8649, 0.4337, -0.1867, ..., 0.7397, -0.2636, 0.2144],
  10. [-0.6207, 0.1668, 0.1561, ..., 1.1218, -0.0985, -0.0937]]])
  11. # 输出尺寸为1x9x768, 即每个字已经使用768维的向量进行了表示,
  12. # 我们可以基于此编码结果进行接下来的自定义操作, 如: 编写自己的微调网络进行最终输出.
  13. 不带头的模型输出结果的尺寸: torch.Size([1, 9, 768])

  • 使用带有语言模型头的模型进行输出:
  1. # 使用带有语言模型头的预训练模型获得结果
  2. with torch.no_grad():
  3. lm_output = lm_model(tokens_tensor)
  4. print("带语言模型头的模型输出结果:", lm_output)
  5. print("带语言模型头的模型输出结果的尺寸:", lm_output[0].shape)

  • 输出效果:
  1. 带语言模型头的模型输出结果: (tensor([[[ -7.9706, -7.9119, -7.9317, ..., -7.2174, -7.0263, -7.3746],
  2. [ -8.2097, -8.1810, -8.0645, ..., -7.2349, -6.9283, -6.9856],
  3. [-13.7458, -13.5978, -12.6076, ..., -7.6817, -9.5642, -11.9928],
  4. ...,
  5. [ -9.0928, -8.6857, -8.4648, ..., -8.2368, -7.5684, -10.2419],
  6. [ -8.9458, -8.5784, -8.6325, ..., -7.0547, -5.3288, -7.8077],
  7. [ -8.4154, -8.5217, -8.5379, ..., -6.7102, -5.9782, -7.6909]]]),)
  8. # 输出尺寸为1x9x21128, 即每个字已经使用21128维的向量进行了表示,
  9. # 同不带头的模型一样, 我们可以基于此编码结果进行接下来的自定义操作, 如: 编写自己的微调网络进行最终输出.
  10. 带语言模型头的模型输出结果的尺寸: torch.Size([1, 9, 21128])

  • 使用带有分类模型头的模型进行输出:
  1. # 使用带有分类模型头的预训练模型获得结果
  2. with torch.no_grad():
  3. classification_output = classification_model(tokens_tensor)
  4. print("带分类模型头的模型输出结果:", classification_output)
  5. print("带分类模型头的模型输出结果的尺寸:", classification_output[0].shape)

  • 输出效果:
  1. 带分类模型头的模型输出结果: (tensor([[-0.0649, -0.1593]]),)
  2. # 输出尺寸为1x2, 可直接用于文本二分问题的输出
  3. 带分类模型头的模型输出结果的尺寸: torch.Size([1, 2])

  • 使用带有问答模型头的模型进行输出:
  1. # 使用带有问答模型头的模型进行输出时, 需要使输入的形式为句子对
  2. # 第一条句子是对客观事物的陈述
  3. # 第二条句子是针对第一条句子提出的问题
  4. # 问答模型最终将得到两个张量,
  5. # 每个张量中最大值对应索引的分别代表答案的在文本中的起始位置和终止位置.
  6. input_text1 = "我家的小狗是黑色的"
  7. input_text2 = "我家的小狗是什么颜色的呢?"
  8. # 映射两个句子
  9. indexed_tokens = tokenizer.encode(input_text1, input_text2)
  10. print("句子对的indexed_tokens:", indexed_tokens)
  11. # 输出结果: [101, 2769, 2157, 4638, 2207, 4318, 3221, 7946, 5682, 4638, 102, 2769, 2157, 4638, 2207, 4318, 3221, 784, 720, 7582, 5682, 4638, 1450, 136, 102]
  12. # 用0,1来区分第一条和第二条句子
  13. segments_ids = [0]*11 + [1]*14
  14. # 转化张量形式
  15. segments_tensors = torch.tensor([segments_ids])
  16. tokens_tensor = torch.tensor([indexed_tokens])
  17. # 使用带有问答模型头的预训练模型获得结果
  18. with torch.no_grad():
  19. start_logits, end_logits = qa_model(tokens_tensor, token_type_ids=segments_tensors)
  20. print("带问答模型头的模型输出结果:", (start_logits, end_logits))
  21. print("带问答模型头的模型输出结果的尺寸:", (start_logits.shape, end_logits.shape))

  • 输出效果:
  1. 句子对的indexed_tokens: [101, 2769, 2157, 4638, 2207, 4318, 3221, 7946, 5682, 4638, 102, 2769, 2157, 4638, 2207, 4318, 3221, 784, 720, 7582, 5682, 4638, 1450, 136, 102]
  2. 带问答模型头的模型输出结果: (tensor([[ 0.2574, -0.0293, -0.8337, -0.5135, -0.3645, -0.2216, -0.1625, -0.2768,
  3. -0.8368, -0.2581, 0.0131, -0.1736, -0.5908, -0.4104, -0.2155, -0.0307,
  4. -0.1639, -0.2691, -0.4640, -0.1696, -0.4943, -0.0976, -0.6693, 0.2426,
  5. 0.0131]]), tensor([[-0.3788, -0.2393, -0.5264, -0.4911, -0.7277, -0.5425, -0.6280, -0.9800,
  6. -0.6109, -0.2379, -0.0042, -0.2309, -0.4894, -0.5438, -0.6717, -0.5371,
  7. -0.1701, 0.0826, 0.1411, -0.1180, -0.4732, -0.1541, 0.2543, 0.2163,
  8. -0.0042]]))
  9. # 输出为两个形状1x25的张量, 他们是两条句子合并长度的概率分布,
  10. # 第一个张量中最大值所在的索引代表答案出现的起始索引,
  11. # 第二个张量中最大值所在的索引代表答案出现的终止索引.
  12. 带问答模型头的模型输出结果的尺寸: (torch.Size([1, 25]), torch.Size([1, 25]))