References and key points:
- https://cloud.tencent.com/developer/article/1387666 the implementation this post follows
- https://github.com/keras-team/keras-contrib the Keras add-on library that provides the CRF layer
- https://docs.floydhub.com/guides/environments/ mind the TensorFlow/Keras version pairing (a quick check is sketched after this list)
- https://github.com/jiesutd/LatticeLSTM where to download the data
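keras-contrib is not distributed through PyPI; its README suggests installing straight from GitHub (pip install git+https://www.github.com/keras-team/keras-contrib.git) or, as done below, cloning the repo and adding the checkout to sys.path. A minimal version sanity check, assuming the standalone-Keras-on-TF1 setup this post targets (the versions in the comments are examples, not hard requirements):

import tensorflow as tf
import keras

# keras-contrib's CRF layer targets standalone Keras 2.x on a TF 1.x backend
print('tensorflow:', tf.__version__)  # e.g. 1.13.1
print('keras:', keras.__version__)    # e.g. 2.2.4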
Data sample
中 B-LOC
国 E-LOC
的 O
天 B-LOC
安 M-LOC
门 E-LOC
我 O
跟 O
他 O
谈 O
笑 O
风 O
生 O
Code
Data processing
import numpy
from collections import Counter
from keras.preprocessing.sequence import pad_sequences
import pickle
import platform
import sys
sys.path.append('/tf/keras/keras-contrib')  # make the cloned keras-contrib checkout importable
def _parse_data(fh):
    # On Windows the blank line between sentences is '\r\n\r\n' and the line
    # break is '\r\n', so pick the separator that matches the platform.
    if platform.system() == 'Windows':
        split_text = '\r\n'
    else:
        split_text = '\n'
    raw_corpus = fh.read().decode('utf-8')
    # one block per sentence, one [char, tag] pair per row
    data = [[row.split() for row in sample.split(split_text)]
            for sample in raw_corpus.strip().split(split_text + split_text)]
    fh.close()
    return data
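A quick look at what _parse_data returns (a sketch; it assumes the demo files downloaded from the LatticeLSTM repo sit next to the script, and the printed pairs are illustrative):

train = _parse_data(open('demo.train.char', 'rb'))
print(len(train))    # number of sentences
print(train[0][:3])  # e.g. [['中', 'B-LOC'], ['国', 'E-LOC'], ['的', 'O']]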
def _process_data(data, vocab, chunk_tags, maxlen=None, onehot=False):
    if maxlen is None:
        maxlen = max(len(s) for s in data)
    word2idx = {w: i for i, w in enumerate(vocab)}
    x = [[word2idx.get(w[0].lower(), 1) for w in s] for s in data]  # map chars missing from vocab to <unk> (index 1)
    y_chunk = [[chunk_tags.index(w[1]) for w in s] for s in data]
    x = pad_sequences(x, maxlen)  # left padding with 0
    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)  # -1 fills the padded label positions
    if onehot:
        y_chunk = numpy.eye(len(chunk_tags), dtype='float32')[y_chunk]
    else:
        y_chunk = numpy.expand_dims(y_chunk, 2)
    return x, y_chunk
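To make the left padding and the -1 label filler concrete, here is a toy run with a hand-made vocabulary and tag list (all names below are invented for illustration):

toy_data = [[['中', 'B-LOC'], ['国', 'E-LOC']]]
toy_vocab = ['<pad>', '<unk>', '中', '国']
toy_tags = ['B-LOC', 'E-LOC', 'O']
x, y = _process_data(toy_data, toy_vocab, toy_tags, maxlen=4)
print(x)          # [[0 0 2 3]] -- zeros on the left are padding
print(y[..., 0])  # [[-1 -1  0  1]] -- -1 marks padded positions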
def process_data(data, vocab, maxlen=100):
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    x = [word2idx.get(w.lower(), 1) for w in data]  # data is a plain string, one char per step
    length = len(x)
    x = pad_sequences([x], maxlen)  # left padding
    return x, length
def load_data():
    train = _parse_data(open('demo.train.char', 'rb'))
    test = _parse_data(open('demo.test.char', 'rb'))
    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    # reserve index 0 for padding and index 1 for <unk>, matching the lookups above
    vocab = ['<pad>', '<unk>'] + [w for w, f in word_counts.items() if f >= 2]
    # sort for a deterministic tag order, so the saved config stays valid across runs
    chunk_tags = sorted(set(line[1] for sample in train for line in sample))
    # save the initial config data
    with open('config.pkl', 'wb') as outp:
        pickle.dump((vocab, chunk_tags), outp)
    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    return train, test, (vocab, chunk_tags)
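A quick shape check of the processed data (a sketch; assumes the demo files are in place, and the shapes in the comments are placeholders):

(train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()
print(train_x.shape)  # (n_sentences, maxlen)
print(train_y.shape)  # (n_sentences, maxlen, 1) -- labels stay sparse for the CRF
print(len(vocab), chunk_tags)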
Building the model
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
import pickle

EMBED_DIM = 200
BiRNN_UNITS = 200

def create_model(train=True):
    if train:
        (train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()
    else:
        # must match the path load_data() saved the config to
        with open('config.pkl', 'rb') as inp:
            (vocab, chunk_tags) = pickle.load(inp)
    model = Sequential()
    model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # randomly initialized embedding
    model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
    crf = CRF(len(chunk_tags), sparse_target=True)
    model.add(crf)
    model.summary()
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    if train:
        return model, (train_x, train_y), (test_x, test_y)
    else:
        return model, (vocab, chunk_tags)

if __name__ == "__main__":
    EPOCHS = 10
    model, (train_x, train_y), (test_x, test_y) = create_model()
    # train the model
    model.fit(train_x, train_y, batch_size=16, epochs=EPOCHS, validation_data=(test_x, test_y))
    model.save('crf.h5')  # keep this path in sync with the prediction code below
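After training, the test split can be scored directly with the metric registered at compile time (a minimal sketch, not part of the original post):

# evaluate returns one value per entry in model.metrics_names ([loss, CRF accuracy] here)
scores = model.evaluate(test_x, test_y, batch_size=16)
print(dict(zip(model.metrics_names, scores)))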
Prediction
import pickle
import numpy as np

# rebuild the network and restore the trained weights
model, (vocab, chunk_tags) = create_model(train=False)
predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
sequence, length = process_data(predict_text, vocab)
model.load_weights('crf.h5')
raw = model.predict(sequence)[0][-length:]  # drop the left padding
result = [np.argmax(row) for row in raw]
result_tags = [chunk_tags[i] for i in result]
per, loc, org = '', '', ''
for s, t in zip(predict_text, result_tags):
    if t in ('B-PER', 'M-PER', 'S-PER', 'E-PER'):
        per += ' ' + s if t == 'B-PER' else s
    if t in ('B-ORG', 'M-ORG', 'S-ORG', 'E-ORG'):
        org += ' ' + s if t == 'B-ORG' else s
    if t in ('B-LOC', 'M-LOC', 'S-LOC', 'E-LOC'):
        loc += ' ' + s if t == 'B-LOC' else s
print(['person:' + per, 'location:' + loc, 'organization:' + org])
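The per-type string concatenation above only separates adjacent entities by the space added at each B- tag. A hedged alternative that collects whole BMES spans instead (the helper name is mine, not from the original post):

def extract_entities(text, tags):
    # walk char/tag pairs and emit (entity_type, surface_string) spans
    entities, buf, cur = [], '', None
    for ch, tag in zip(text, tags):
        if tag == 'O':
            if buf:  # flush a span truncated by an O tag
                entities.append((cur, buf))
                buf, cur = '', None
            continue
        pos, _, etype = tag.partition('-')  # e.g. 'B-PER' -> ('B', 'PER')
        if pos in ('B', 'S') and buf:  # a new span starts, flush the old one
            entities.append((cur, buf))
            buf = ''
        buf += ch
        cur = etype
        if pos in ('E', 'S'):  # span complete
            entities.append((cur, buf))
            buf, cur = '', None
    if buf:
        entities.append((cur, buf))
    return entities

print(extract_entities(predict_text, result_tags))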
Error handling
- TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [bool, float32] that don't all match.
  Fix: remove mask_zero=True from the Embedding layer.
  For reference, the Keras documentation on mask_zero: whether the input value 0 is a special "padding" value that should be masked out, which is useful for variable-length recurrent layers. If set to True, all subsequent layers in the model must support masking, otherwise an exception is raised. As a consequence, when mask_zero is True, index 0 cannot be used for a regular word in the vocabulary (input_dim should equal vocabulary size + 1).
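Concretely, the workaround is a one-line change in create_model (a sketch of the fixed line):

# before (propagates a boolean mask that the CRF's internal concatenation chokes on):
# model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))
# after: no mask, index 0 is just the <pad> token's ordinary embedding
model.add(Embedding(len(vocab), EMBED_DIM))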