def apply(self, inputs, states, cells, mask=None):
    """Compute one step of an LSTM-style recurrence.

    The pre-activation ``dot(states, self.W_state) + inputs`` is split
    along its second axis into four consecutive chunks of width
    ``self.dim``: input gate, forget gate, cell candidate and output gate,
    in that order. ``self.gate_activation`` is applied to the three gates
    and ``self.activation`` to the candidate and to the new cells.
    ``self.bias`` is added to the forget-gate pre-activation only.

    Parameters
    ----------
    inputs : tensor
        Input contribution to the pre-activation; must be addable to
        ``dot(states, self.W_state)``. Presumably ``(batch, 4 * dim)`` --
        confirm against the caller.
    states : tensor
        Previous hidden states, multiplied by ``self.W_state``.
    cells : tensor
        Previous cell states; combined elementwise with the gates.
    mask : tensor, optional
        One value per row of the first axis (it is indexed as
        ``mask[:, None]``). The result is the blend
        ``mask * new + (1 - mask) * previous``, so rows with mask 1 take
        the new values and rows with mask 0 keep ``states``/``cells``.
        ``None`` disables masking.

    Returns
    -------
    tuple
        ``(next_states, next_cells)``.
    """
    def slice_last(x, no):
        # Columns [no * dim, (no + 1) * dim) hold the no-th chunk
        # (0: input gate, 1: forget gate, 2: candidate, 3: output gate).
        return x[:, no * self.dim: (no + 1) * self.dim]

    activation = tensor.dot(states, self.W_state) + inputs
    in_gate = self.gate_activation.apply(slice_last(activation, 0))
    pre = slice_last(activation, 1)
    # ones_like expands the bias to the full shape of the forget slice.
    forget_gate = self.gate_activation.apply(
        pre + self.bias * tensor.ones_like(pre))
    next_cells = (
        forget_gate * cells +
        in_gate * self.activation.apply(slice_last(activation, 2)))
    out_gate = self.gate_activation.apply(slice_last(activation, 3))
    next_states = out_gate * self.activation.apply(next_cells)
    # Compare against None explicitly: truth-testing an array/tensor mask
    # is not a presence check (numpy raises ValueError for multi-element
    # arrays, and an all-zero single-element mask would skip masking).
    if mask is not None:
        keep = mask[:, None]
        next_states = keep * next_states + (1 - keep) * states
        next_cells = keep * next_cells + (1 - keep) * cells
    return next_states, next_cells
# Comments
# Table of contents