输入x
- 形状 (128, 512) —> (batch_size, maxlen)
x_mask
- (128, 512) 有单词的位置是1, 没有的位置是0
- x_mask 形状 (128, 512, 1) —-> (batch_size, maxlen, 1)
输入y
- 形状 (128, seq_len) —> (batch_size, seq_len) seq_len:一个batch中的最大长度
- seq_len 变长
Embedding层
- x (128, 512) —-> (128, 512, 256)
y (128, 512) —-> (128, 512, 256)
编码器Encoder
输入x, x_mask
x (128, 512, 256)
- x_mask (128, 512, 1)
256 —> 128 ==> embed_size —-> hidden_size
双向LSTM
设置隐层维度 z_dim/2 = 64
- 双向LSTM 因此输出维度 128
- LSTM out (batch_size, maxlen, 128)
- out * x_mask
- (128, 512, 128) (128, 512, 1) 广播
- 输出 (128, 512, 128) —> (batch_size, maxlen, hidden_size)
- (128, 512, 1) 为0的值,将 整个 256 全设置为0
输出 (batch_size, seq_len, 128) —>(128, seq_len, 128)
LayerNormlization
解码器器Decoder
输入 y, y_mask
- y (128, seq_len, 256)
-
单向LSTM
设置隐层维度 z_dim = 128
- 双向LSTM 因此输出维度 128
输出 (batch_size, seq_len, 128) —>(128, seq_len, 128)
LayerNormlization
Attention
输入 Q、K、V、x_mask —-> (y, x, x, x_mask)
- n_heads —-> 8
-
1线性层
输入 (batch_size, len, hidden_size) —-> 输出 (batch_size, len, head_size * n_head)
- qw —-> (batch_size, seq_len, head_size * n_head) —> (128, y_seq_len, 128)
- kw ——> (batch_size, 512, head_size * n_head) —> (128, 512, 128)
vw ——> (batch_size, 512, head_size * n_head) —> (128, 512, 128)
2形状变换
qw —-> (batch_size, seq_len, n_head, head_size) —> (128, y_len, 8, 16)
- kw —-> (batch_size, 512, n_head, head_size) —>(128, 512, 8, 16)
vw —-> (batch_size, 512, n_head, head_size) —>(128, 512, 8, 16)
3维度置换
(qw, (0, 2, 1, 3))
- qw —-> (batch_size, n_head, seq_len, head_size) —> (128, 8, y_len, 16)
- kw —-> (batch_size, n_head, 512, head_size) —>(128, 8, 512, 16)
vw —-> (batch_size, n_head, 512, head_size) —>(128, 8, 512, 16)
4Dot-Product
输入 qw (128, 8, y_len, 16) kw (128, 8, 512, 16)
- a = qw^T dot kw
- PyTorch torch.matmul(q, v.transpose(-1, -2), 第一个维度是batch交换后面两个
- K.batch_dot(qw, kw, [3, 3]) —> (batch_size, head_size, y_len, x_len) —> (128, 8, y_len, 512)
- (128, 8, y_len, 512)
-
5scaled-dot Product
a = a/np.sqrt(d_k)
- d_k就是头的size=16 np.sqrt(d_k) = 4
-
6a = K.permute_dimensions(a, (0, 3, 2, 1))
输入 (128, 8, y_len, 512)
输出 (batch_size, 512, y_len, 8)
7mask
y_len 一个batch中 最大句子长度
- 输入 a (batch_size, 512, y_len, 16) x_mask (128, 512, 1)
for _ in range(K.ndim(x) - K.ndim(mask)):
mask = K.expand_dims(mask, K.ndim(mask))
- K.ndim(x) = 4, K.ndim(mask) = 3
- 在K.ndim(mask)=3 上再增加一个维度,K.expand_dims(mask, K.ndim(mask))
- mask -> mask —> (128, 512, 1) ——> (128, 512, 1, 1)
- return x - (1 - mask) * 1e10
- 输出维度 (batch_size, 8, 512, y_len)
- mask 中0的表示被mask
- 对于a 将mask为0的位置,设置为极大负数
输出 (batch_size, 512, y_len, 8)
8 a = K.permute_dimensions(a, (0, 3, 2, 1))
(batch_size, 512, y_len, 8)—>(batch_size, 8, y_len, 512)
(128, 512, y_len, 8) —-> (128, 8, y_len, 512)
8softmax
输入 a (128, 8, y_len, 512)
- y_len × 512
0 ……. 511
15 ….. 511
每一行进行求softmax,表示 x 对于 y_len上每个单词的贡献
9 softmax(score) × value
输入 a (batch_size, n_head, y_len, 512) —>(128, 8, y_len, 512)
输入 vw (batch_size, n_head, 512, head_size) —>(128, 8, 512, 16)
(?, 8, ?, 16)
o = K.batch_dot(a, vw, [3, 2])scores 是一个打分矩阵 (y_len, x_len)
- value 代表x x_len
- 输出 y_len, 表示 x 在score上求解 得到的输出 y
输出 (128, 8, y_len, 16)
- o = K.permute_dimensions(o, (0, 2, 1, 3)) —> (128, y_len, 8, 16)
- o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim)) —> (128, y_len, 8*16)
9.1 mask
o = self.mask(o, q_mask, ‘mul’)
输出 (128, y_len, 128)
10 残差网络
- xy = Attention(8, 16)([y, x, x, x_mask])
(128, y_len, 128)
- y -> (batch_size, y_len, hidden_size) —> (128, y_len, 128)
xy = Concatenate()([y, xy])
- 输出 (128, y_len, 256)
11 输出分类
xy = Dense(char_size)(xy)
xy = Activation('relu')(xy)
xy = Dense(len(chars) + 4)(xy)
xy = Activation('softmax')(xy)
Dense char_size 就是embed_size = 256
(128, y_len, 256) —> (128, y_len, 256)
Activation relu
- 激活,将输出的负数置0
Dense(len(chars)+4)
- 将输出映射到词典上
(128, y_len, 256) —> (128, y_len, VACAB_SIZE)
Softmax
- 得到在词典上每个单词取值的概率
12 损失函数
将输出看作在整个词典上的分类
sparse_categorical_crossentropy 数字编码
[2, 0, 1]
- categorical_crossentropy one-hot编码
[0, 0, 1], [1, 0, 0], [0, 1, 0]
cross_entropy = K.sparse_categorical_crossentropy(y_in[:, 1:], xy[:, :-1])
cross_entropy = K.sum(cross_entropy * y_mask[:, 1:, 0]) / K.sum(y_mask[:, 1:, 0])
- y_in 原始的输入 (batch_size, y_len)
- xy —> (batch_size, y_len, VOCAB_SIZE)