def forward(self, x, hidden):
    """Run a single LSTM step, optionally applying one dropout variant.

    Args:
        x: Input tensor. It is flattened to ``(x.size(1), -1)`` before the
            input projection ``self.i2h`` is applied.
            NOTE(review): presumably shaped (1, batch, input_size) --
            confirm against callers.
        hidden: Tuple ``(h, c)`` with the previous hidden and cell state.
            Each is flattened to ``(t.size(1), -1)``.

    Returns:
        ``(h_t, (h_t, c_t))`` where both tensors have been reshaped to
        ``(1, rows, -1)``; the new hidden state doubles as the output.

    Dropout is only active when ``self.training`` is true and
    ``self.dropout > 0``. ``self.dropout_method`` selects where it acts:

    * ``'semeniuta'``: ``F.dropout`` on the candidate update ``g_t``.
    * ``'moon'``: ``self.mask`` on the new cell state ``c_t``.
    * ``'pytorch'``: ``F.dropout`` on the new hidden state ``h_t``.
    * ``'gal'``: ``self.mask`` on the new hidden state ``h_t``.

    For ``'moon'`` and ``'gal'`` the masked tensor is rescaled by
    ``1 / (1 - self.dropout)``. ``self.mask`` is read here but not created
    here -- presumably it is sampled elsewhere before the call; verify.
    """
    do_dropout = self.training and self.dropout > 0.0
    hidden_size = self.hidden_size

    h, c = hidden
    # Flatten everything to 2-D, using dimension 1 as the row count.
    h = h.view(h.size(1), -1)
    c = c.view(c.size(1), -1)
    x = x.view(x.size(1), -1)

    # Linear mappings
    preact = self.i2h(x) + self.h2h(h)

    # Activations: the first 3 * hidden_size columns are the sigmoid gates
    # (input, forget, output); the remaining columns are the tanh candidate.
    gates = preact[:, :3 * hidden_size].sigmoid()
    g_t = preact[:, 3 * hidden_size:].tanh()
    i_t = gates[:, :hidden_size]
    f_t = gates[:, hidden_size:2 * hidden_size]
    o_t = gates[:, -hidden_size:]

    # Cell computations
    if do_dropout and self.dropout_method == 'semeniuta':
        g_t = F.dropout(g_t, p=self.dropout, training=self.training)

    c_t = th.mul(c, f_t) + th.mul(i_t, g_t)

    if do_dropout and self.dropout_method == 'moon':
        # Applied out-of-place (not through ``.data``) so the mask really
        # lands on c_t and the backward pass sees the same mask and scale.
        c_t = th.mul(c_t, self.mask) * (1.0 / (1.0 - self.dropout))

    h_t = th.mul(o_t, c_t.tanh())

    if do_dropout:
        if self.dropout_method == 'pytorch':
            h_t = F.dropout(h_t, p=self.dropout, training=self.training)
        elif self.dropout_method == 'gal':
            # Same reasoning as for 'moon': keep the masking inside autograd.
            h_t = th.mul(h_t, self.mask) * (1.0 / (1.0 - self.dropout))

    # Reshape for compatibility: restore the leading singleton dimension.
    h_t = h_t.view(1, h_t.size(0), -1)
    c_t = c_t.view(1, c_t.size(0), -1)
    return h_t, (h_t, c_t)
评论列表
文章目录