def pool(self, WX, skip_mask=None):
    """Apply recurrent pooling over the time axis and return the stacked outputs.

    The pooling variant is selected by ``len(self._pooling)``:
    1 -> f-pooling (``WX`` holds Z, F), 2 -> fo-pooling (Z, F, O),
    3 -> ifo-pooling (Z, F, O, I).

    Args:
        WX: sequence of gate pre-activations, ``len(self._pooling) + 1`` items
            in the order Z, F[, O[, I]]. Time is read from axis 2 and indexed
            as the last axis -- presumably (batch, channels, time); TODO confirm.
        skip_mask: optional mask indexed as ``skip_mask[:, t, None]`` and
            multiplied into the candidate update of each step (the original
            comment says it is meant for skipping PAD in seq2seq).
            Presumably shaped (batch, time) -- TODO confirm.

    Returns:
        ``self.H``: the per-step outputs stacked along axis 2, appended to
        whatever ``self.H`` already held before this call.

    Side effects:
        Reads and updates ``self.ct``, ``self.ht`` and ``self.H``.
    """
    num_gates = len(self._pooling)
    assert 1 <= num_gates <= 3, "unsupported pooling mode"
    assert len(WX) == num_gates + 1

    # Activations are applied in the fixed order Z, F, O, I.
    gates = tuple(WX)
    Z = functions.tanh(gates[0])
    F = self.zoneout(gates[1])
    O = functions.sigmoid(gates[2]) if num_gates >= 2 else None
    I = functions.sigmoid(gates[3]) if num_gates >= 3 else None

    steps = []
    # `range` instead of `xrange` so the method also runs on Python 3.
    for t in range(Z.shape[2]):
        zt = Z[..., t]
        ft = F[..., t]
        # Without an explicit input gate, the input gate is tied to 1 - f.
        it = 1 - ft if I is None else I[..., t]

        candidate = it * zt
        if skip_mask is not None:
            candidate = candidate * skip_mask[:, t, None]

        if self.ct is None:
            # No previous cell state: treat it as zero. Using `it` here (rather
            # than a hard-coded 1 - ft) makes ifo-pooling honour its input gate
            # on the first step; for f/fo-pooling the value is unchanged.
            self.ct = candidate
        else:
            self.ct = ft * self.ct + candidate

        self.ht = self.ct if O is None else O[..., t] * self.ct
        steps.append(functions.expand_dims(self.ht, 2))

    if steps:
        # Concatenate once instead of re-concatenating the growing history on
        # every step, which copied O(T^2) data.
        if self.H is not None:
            steps.insert(0, self.H)
        if len(steps) == 1:
            self.H = steps[0]
        else:
            self.H = functions.concat(tuple(steps), axis=2)
    return self.H
评论列表
文章目录