def forward_one_step(self, X, ht_enc, H_enc, skip_mask):
    pad = self._kernel_size - 1
    # keep only the convolution output for the current (newest) time step
    WX = self.W(X)[:, :, -pad-1, None]
    # project the encoder's last hidden state and broadcast it over the time axis
    Vh = self.V(ht_enc)
    Vh, WX = functions.broadcast(functions.expand_dims(Vh, axis=2), WX)

    # f-pooling
    Z, F, O = functions.split_axis(WX + Vh, 3, axis=1)
    Z = functions.tanh(Z)
    F = self.zoneout(F)
    O = functions.sigmoid(O)
    T = Z.shape[2]

    # compute ungated hidden states
    for t in xrange(T):
        z = Z[..., t]
        f = F[..., t]
        if self.contexts is None:
            ct = (1 - f) * z
            self.contexts = [ct]
        else:
            ct = f * self.contexts[-1] + (1 - f) * z
            self.contexts.append(ct)

    if skip_mask is not None:
        assert skip_mask.shape[1] == H_enc.shape[2]
        # large negative bias so padded encoder positions receive ~zero attention weight
        softmax_bias = (skip_mask == 0) * -1e6

    # compute attention weights (eq.8)
    H_enc = functions.swapaxes(H_enc, 1, 2)
    for t in xrange(T):
        ct = self.contexts[t - T]  # t-th of the T contexts appended above
        bias = 0 if skip_mask is None else softmax_bias[..., None]  # to skip PAD
        mask = 1 if skip_mask is None else skip_mask[..., None]  # to skip PAD
        alpha = functions.batch_matmul(H_enc, ct) + bias
        alpha = functions.softmax(alpha) * mask
        alpha = functions.broadcast_to(alpha, H_enc.shape)  # copy
        kt = functions.sum(alpha * H_enc, axis=1)  # attention-weighted sum of encoder states
        ot = O[..., t]
        self.ht = ot * self.o(functions.concat((kt, ct), axis=1))

        # append ht to the history of decoder hidden states
        if self.H is None:
            self.H = functions.expand_dims(self.ht, 2)
        else:
            self.H = functions.concat((self.H, functions.expand_dims(self.ht, 2)), axis=2)

    return self.H
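
To make the two loops above easier to follow, here is a minimal, self-contained NumPy sketch of the same computation for a single example: the f-pooling recurrence c_t = f_t * c_{t-1} + (1 - f_t) * z_t, followed by a dot-product attention readout over the encoder states (eq. 8) and the gated output. This is not the repository's Chainer code; the shapes, the random toy tensors, and the W_o matrix (standing in for the self.o linear layer) are illustrative assumptions.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

hidden_dim, enc_len, dec_len = 4, 6, 3
rng = np.random.default_rng(0)

# toy gate activations and encoder states (stand-ins for Z, F, O and H_enc above)
Z = np.tanh(rng.normal(size=(dec_len, hidden_dim)))
F = 1.0 / (1.0 + np.exp(-rng.normal(size=(dec_len, hidden_dim))))
O = 1.0 / (1.0 + np.exp(-rng.normal(size=(dec_len, hidden_dim))))
H_enc = rng.normal(size=(enc_len, hidden_dim))
W_o = rng.normal(size=(2 * hidden_dim, hidden_dim))  # stand-in for the self.o projection

c = np.zeros(hidden_dim)
for t in range(dec_len):
    # f-pooling: interpolate between the previous context and the candidate state
    c = F[t] * c + (1 - F[t]) * Z[t]
    # dot-product attention scores against every encoder state, then a weighted sum
    alpha = softmax(H_enc @ c)          # (enc_len,)
    k = alpha @ H_enc                   # attended encoder summary, (hidden_dim,)
    # project concat((k, c)) and apply the output gate, mirroring self.o above
    h = O[t] * (np.concatenate((k, c)) @ W_o)
    print("step", t, "h shape:", h.shape)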