1、基本思想
- 不同领域的语料有不同的特点,词汇表也大相径庭。垂直领域中会包含相对一部分的专业词汇(有的专业词汇甚至会比较长),专业词汇的词干部分可能在 token 词表中不存在,但这并不影响其在 tokenize 过程中产生
[UNK]
,只是可能被更多个 BPE 词片段切分为更零散的 token 列表而已。然而,专业词汇被切分得过于零散可能不利于模型的语义理解。 - 首先考虑为什么要进行 BPE 构建 token 词表。
- 主要是为了在保持合理语义的情况下,减少模型训练过程中的
[UNK]
,同时使模型参数得到合理的控制(词表不能过大);合理的 BPE 构建得到的 token 词表在对一个句子进行 tokenize 时,能把一个长度适中的句子切分为长度适中的 token 列表,使得模型训练过程中设置max_seq_len
时对应的结果是一个语义较为完整的片段(可以是连续的多个句子)。 - 考虑以下两种极端的 BPE 构建 token 词表的情况:
- 1)将每个英文字母作为 BEP 的 token 词表
- 此时 uncased 情况下 26 个英文字母即可表达所有英文单词(忽略标点符号等特殊符号,当然,考虑这些对于 token 词表的扩大幅度也极为有限,此处仅做简单原理说明,故不考虑),实现不会产生
[UNK]
的效果,此时 token 词表极小。 - 但由于词表中的每个 token 均为单个字符,表达一个简短的句子时也要求
max_seq_len
要足够大(否则构成的一个片段并不包含太完整的语义信息)。- 保守估计,原先正常 BEP 的 token 词表下,结合 max_seq_len 设置为 256 时,在这种特殊的单字符 token 词表下,对应的 max_seq_len 可能就需要为
[256*5, 256*8]
(此处假设一个句子中的英文单词平均由 5 到 8 个字符构成)
- 保守估计,原先正常 BEP 的 token 词表下,结合 max_seq_len 设置为 256 时,在这种特殊的单字符 token 词表下,对应的 max_seq_len 可能就需要为
- 相比之下,总参数似乎是减少的了(因为
max_seq_len
长度增大的倍数远远小于 token 词表减少的倍数),但这种模式下,每个 token 几乎不能重点突出表示相应的语义信息,训练得到的模型因此也很难实现不同语义的辨别判断。- 通过该逻辑可以看出,含有特定语义信息的合理的 token 词表对于训练模型有着重大意义。
- 此时 uncased 情况下 26 个英文字母即可表达所有英文单词(忽略标点符号等特殊符号,当然,考虑这些对于 token 词表的扩大幅度也极为有限,此处仅做简单原理说明,故不考虑),实现不会产生
- 2)使用独立的英文单词直接构成 BPE 中的 token 词表
- 由于同一英文单词可能存在很多变种(负数、所有格等),同时存在许多连接符(如
-
)拼接构成的组合词,甚至一些 NER 词汇,通过该方式构造 token 词表,难免不能包括全、且构造成的词表会比较庞大,可能相比于正常 BEP 的 token 词表将增大十来倍(但也不能保证囊括所有词汇,难免还是会存在[UNK]
)。 - 同时,在该方式的情况下,为了正常表示一个语义完整的片段,
max_seq_len
参数可能会相比于正常 BEP 的 token 词表情况下略有减小,但幅度不会太大。因此,该方式会极大增大模型参数量。 - 尽管模型参数大大增加,但太细化的单词作为 token 对于模型的训练也不一定起到正面作用,比如词根相同、但后缀不同的词,基于该方式进行词的数字化表示时,相似的词并不存在太多的联系。
- 由于同一英文单词可能存在很多变种(负数、所有格等),同时存在许多连接符(如
- 1)将每个英文字母作为 BEP 的 token 词表
- 主要是为了在保持合理语义的情况下,减少模型训练过程中的
- 而基于 BPE 处理得到的 token 词表,词根可得到比较好的切分,因此在对词根相同的词进行数字化时,是有相同成分存在的。同时为了减少
[UNK]
,切分完词根后也会补充更为细粒度的单词片段,确保相同词根、不同词均可得到恰当的表示。 但是 BERT 的 pre-train 结果并非垂直于某领域训练得到,在具体垂直领域中的某些专业词汇对应的词根可能在 token 词表中并不能找到相应的结果,因此缺少特定的完整语义片段 token。因此,在训练垂直领域的 BERT 模型时,基于特定的语料对原 BERT 中的 token 词表做相应的扩充,是一个理论上较为有效的实验方向(待实验分析证明)。
2、实验方法
(1)获取语料库中的单词词干词汇表(json 存储)
统计语料库中 word 以及其对应的词频(过滤掉标点、停用词、字符长度小于某特定值的 word),并逐步做以下处理(随后存储于 json 格式文件,作为通用领域词汇表):
- 基于 spacy 获取 word 的 lemma 属性
- 恢复为词的原有形式,可大大减少词汇量,因为同类词会有多种不同形式
- 该处理的必要性分析:垂直领域的词汇中的专业术语可能是通过连接符
-
拼接的组合词,如formaldehyde-induced-fluorescence
,如果忽略该步骤直接通过 nltk 获取词干,则会得到的结果并不理想。 - 经过该处理后,还并不是真正的单词词干部分,如
successfully
经过处理后仍为原词successfully
。
- 基于 nltk 获取经过 spacy 处理后的词干
- 可以将诸如
successfully
处理为success
```python import json import io from collections import Counter import re import spacy from pysenal import get_chunk
- 可以将诸如
- 基于 spacy 获取 word 的 lemma 属性
def get_stopwords(): from gensim.parsing.preprocessing import STOPWORDS return list(STOPWORDS)
def remove_en_punct(text):
# jy: 注意, "." 和 "?" 要进行转义, 因为其在 re 中都有特殊含义;
text = re.sub('\.|,|;|:|\?|!', " ", text)
return text
def get_uncasedWord_count(f_name, f_out_name, batch_word_size=50000): “”” f_name: 语料库文件 语料库文件尽量不要太大, 建议 1G 左右, 如果大文件可切分并发处理, 最终对输出的 json 文件进行合并即可; f_out_name: 输出的 json 文件;
该函数实现基本的过滤, 并进行 lemmatize;
"""
with io.open(f_name, "r", encoding="utf8") as f_, \
io.open(f_out_name, "w", encoding="utf8") as f_out:
str_in = f_.read()
str_in = str_in.lower()
str_in = str_in.replace("\n", " ")
str_in = remove_en_punct(str_in)
ls_word = str_in.split()
dict_word_count = dict(Counter(ls_word))
ls_stopwords = get_stopwords()
ls_word = list(dict_word_count.keys())
for word in ls_word:
if word in ls_stopwords or len(word) < 5:
del dict_word_count[word]
dict_res = {}
nlp = spacy.load("en_core_web_sm")
ls_all_word = []
ls_all_count = []
for word, count in dict_word_count.items():
ls_all_word.append(word)
ls_all_count.append(count)
for ls_word, ls_count in zip(get_chunk(ls_all_word, batch_word_size),
get_chunk(ls_all_count, batch_word_size)):
str_words = " || ".join(ls_word)
print(len(str_words))
# jy: 传入 nlp() 中的文本长度不能超 1000000, 否则会报错;
# 同时, 由于传入短文本或长文本对执行效率的影响并不是特别显著, 故
# 为了提高效率, 尽可能一次性拼接成较长的文本长度(并确保不超长度限制)
doc = nlp(str_words)
ls_lemma = [i.lemma_ for i in doc]
str_lemma = " ".join(ls_lemma)
ls_word_lemma = str_lemma.split(" || ")
assert len(ls_word_lemma) == len(ls_count)
for i in range(len(ls_count)):
word_lemmas = ls_word_lemma[i].split()
word_lemmas = [i for i in word_lemmas if i not in ls_stopwords \
and len(i) >=5]
for lemma in word_lemmas:
if lemma not in dict_res:
dict_res[lemma] = ls_count[i]
else:
dict_res[lemma] += ls_count[i]
json.dump(dict_res, f_out)
“”” f_name = “wiki1m_for_simcse.txt” f_out_name = “uncased-common-words.json” get_uncasedWord_count(f_name, f_out_name) “””
def is_filter_word(word, num_ratio=0.3): str_num = “”.join(re.findall(r”\d+”, word)) if len(str_num) / len(word) >= num_ratio: return True return False
def get_special_wordStem(f_common_name, f_uncommon_name, f_vocab_name, f_special_name, min_count=5): “”” f_common_name: 通用领域语料经过 spacy 进行 lemmetize 后的词表文件; f_uncommon_name: 垂直领域语料经过 spacy 进行 lemmetize 后的词表文件; f_vocab_name: BPE token 词表文件(每行对应一个词); f_special_name: 存储特有词干的 json 文件; min_count: 出现次数少于该数值的 token 均会被过滤掉; “”” with io.open(f_common_name, “r”) as f_common, \ io.open(f_uncommon_name, “r”) as f_uncommon, \ io.open(f_vocab_name, “r”) as f_vocab, \ io.open(f_special_name, “w”) as f_out: dict_common = json.load(f_common) dict_uncommon = json.load(f_uncommon) for word, count in dict_common.items(): if word in dict_uncommon: del dict_uncommon[word]
set_token = set()
for line in f_vocab:
token = line.strip()
if "[unused" in token:
continue
set_token.add(token)
dict_res = {}
# jy: 过滤掉出现次数少于 min_count 的词干;
for word, count in dict_uncommon.items():
# jy: 通常完整的 word 很少会与 set_token 中有重合(此处可做进一步处理);
if count > min_count and not is_filter_word(word, num_ratio=0.1) \
and word not in set_token:
dict_res[word] = count
print(len(dict_res))
json.dump(dict_res, f_out)
f_common_name = “common-words-uncased.json” f_uncommon_name = “pharm-words-uncased.json” f_vocab_name = “/home/huangjiayue/04_SimCSE/jy_model/bert-base-uncased/vocab.txt” f_special_name = “special_wordStem.json” get_special_wordStem(f_common_name, f_uncommon_name, f_vocab_name, f_special_name, min_count=200)
- 获取垂直领域词干后,可以进一步去除`vocab.txt`中的前缀或后缀部分:
```python
def get_dict_len_word(ls_word):
"""
ls_word: token 词列表
返回结果: {len(token): 相同长度的 token 集合}
"""
dict_ = {}
for word in ls_word:
if len(word) not in dict_:
dict_[len(word)] = set([word])
else:
dict_[len(word)].add(word)
return dict_
def analysis_vocab_txt(f_name):
"""
f_name: token 词表文件(如: vocab.txt)
返回结果: 前缀字典和后缀字典(字典形式如: {len(token): 相同长度的 token 集合}),
和 token 列表;
前缀 token 为词表文件中字符个数大于 4 且不以 "#" 开头的 token
后缀 token 为词表文件中字符个数大于 5 且以 "#" 开头的 token
"""
with open(f_name, "r", encoding="utf8") as f_:
ls_token = [i.strip() for i in f_.read().split("\n")]
ls_prefix = [i for i in ls_token if "#" not in i and len(i) > 4]
ls_suffix = [i.lstrip("#") for i in ls_token if "#" in i and len(i) >= 6]
dict_prefix = get_dict_len_word(ls_prefix)
dict_suffix = get_dict_len_word(ls_suffix)
return dict_prefix, dict_suffix, ls_token
def get_new_token(f_name, ls_special_token):
"""
f_name: token 词表文件(如: vocab.txt)
ls_special_token: 垂直领域特有 token;
"""
dict_prefix, dict_suffix, ls_vocab_token = analysis_vocab_txt(f_name)
min_len_prefix = min(dict_prefix.keys())
min_len_suffix = min(dict_suffix.keys())
ls_token = []
for token in ls_special_token:
is_prefix = False
# jy: 倒叙遍历, 确保最长前缀优先匹配;
for i in range(min(17, len(token)), min_len_prefix, -1):
prefix_ = token[: i+1]
if prefix_ in dict_prefix[i]:
rm_prefix = token.lstrip(prefix_)
if rm_prefix:
ls_token.append(rm_prefix)
is_prefix = True
print("prefix== %s, token == %s, rm_prefix == %s" % (
prefix_, token, rm_prefix))
# jy: 确保一个最长前缀匹配到后即使用该前缀处理后退出;
break
if not is_prefix:
ls_token.append(token)
is_suffix = False
# jy: 倒叙遍历, 确保最长后缀优先匹配;
for i in range(min(10, len(token)), min_len_suffix, -1):
suffix_ = token[-i:]
if suffix_ in dict_suffix[i]:
rm_suffix = token.rstrip(suffix_)
if rm_suffix:
ls_token.append(rm_suffix)
is_suffix = True
print("suffix== %s, token == %s, rm_suffix == %s" % (
suffix_, token, rm_suffix))
# jy: 确保一个最长后缀匹配到后即使用该后缀处理后退出;
break
if not is_suffix:
ls_token.append(token)
# jy: 去重;
ls_dedup_token = []
for tk in ls_token:
if tk not in ls_dedup_token and tk not in ls_vocab_token:
ls_dedup_token.append(tk)
return ls_dedup_token
(2)获取垂直领域的专有词干词汇表
- 准备通用领域语料、垂直领域语料,基于以上方式获取词干词汇表,得到两个 json 文件。
- 对两个 json 文件中的词干词汇进行去重,获取垂直领域的独有词干词汇表,并基于相应词频取 top-k 个扩充到 BERT 的 token 词表中(扩充方式见上一篇文章说明)。