def _build(self, inputs, state):
    """Advance the convolutional LSTM by a single step.

    Both the input and the previous hidden state are passed through their own
    convolutions. The summed result is split into four equal gate
    pre-activations, in the order: input gate, candidate value, forget gate,
    output gate.

    Args:
      inputs: Input tensor for this step. It is fed to the "input" convolution
        and, when skip connections are enabled, appended to the output.
      state: Pair ``(hidden, cell)`` carried over from the previous step.

    Returns:
      ``(output, (output, next_cell))``. ``output`` is
      ``tanh(next_cell) * sigmoid(output_gate)``, with ``inputs`` concatenated
      on the last axis when ``self._skip_connection`` is set.
    """
    prev_hidden, prev_cell = state
    # Fetch both convolutions before calling either, so that a missing key
    # fails before any ops are built.
    conv_in, conv_rec = (self._convolutions["input"],
                         self._convolutions["hidden"])
    preactivations = conv_in(inputs) + conv_rec(prev_hidden)

    # Presumably the channel axis: assumes tensors are laid out as
    # (batch, *spatial, channels) with self._conv_ndims spatial dims.
    # TODO confirm against the convolution module.
    channel_axis = self._conv_ndims + 1
    in_gate, candidate, forget_gate, out_gate = tf.split(
        value=preactivations, num_or_size_splits=4, axis=channel_axis)

    # The forget bias shifts the forget-gate pre-activation before the sigmoid.
    retained = tf.sigmoid(forget_gate + self._forget_bias) * prev_cell
    new_cell = retained + tf.sigmoid(in_gate) * tf.tanh(candidate)
    new_hidden = tf.tanh(new_cell) * tf.sigmoid(out_gate)

    if self._skip_connection:
        # NOTE(review): the concatenated tensor is also returned as the
        # recurrent hidden state below, so the next step's "hidden"
        # convolution sees the extra input channels. Confirm that the
        # declared state size / initial state account for this.
        new_hidden = tf.concat([new_hidden, inputs], axis=-1)
    return new_hidden, (new_hidden, new_cell)
评论列表
文章目录