gated_rnn.py source code

python

Project: sonnet  Author: deepmind
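import tensorflow as tf


def _build(self, inputs, state):
    """Run one convolutional LSTM step; returns (output, (hidden, cell))."""
    # The recurrent state is a (hidden, cell) pair from the previous step.
    hidden, cell = state
    input_conv = self._convolutions["input"]
    hidden_conv = self._convolutions["hidden"]
    # The summed convolutions hold the pre-activations of all four gates,
    # stacked along the axis that is split below.
    next_hidden = input_conv(inputs) + hidden_conv(hidden)
    # Axis conv_ndims + 1 is the last axis of a (batch, spatial..., channels)
    # tensor, i.e. the channel axis if the layout is channels-last.
    gates = tf.split(value=next_hidden, num_or_size_splits=4,
                     axis=self._conv_ndims + 1)

    input_gate, next_input, forget_gate, output_gate = gates
    # Keep part of the old cell state, then add the gated candidate values.
    next_cell = tf.sigmoid(forget_gate + self._forget_bias) * cell
    next_cell += tf.sigmoid(input_gate) * tf.tanh(next_input)
    output = tf.tanh(next_cell) * tf.sigmoid(output_gate)

    if self._skip_connection:
        # Append the raw inputs to the output along the last axis. Note that
        # this concatenated tensor is also returned as the next hidden state.
        output = tf.concat([output, inputs], axis=-1)
    return output, (output, next_cell)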
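
The gate arithmetic in _build can be checked by hand without TensorFlow. The following is a minimal pure-Python sketch, not part of sonnet: it assumes the convolutions have already been summed into a flat list of 4 * C pre-activations for a single spatial position, and the names sigmoid and lstm_gate_update, as well as the default forget_bias=1.0, are hypothetical and chosen only for illustration.

```python
import math


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def lstm_gate_update(pre_activations, cell, forget_bias=1.0):
    """Apply the LSTM gate arithmetic to one position's channel vector.

    pre_activations: list of 4 * C floats laid out as
        [input_gate | next_input | forget_gate | output_gate].
    cell: list of C floats (previous cell state).
    Returns (output, next_cell), each a list of C floats.
    """
    channels = len(cell)
    if len(pre_activations) != 4 * channels:
        raise ValueError("expected 4 * len(cell) pre-activations")
    # Plays the role of tf.split(..., num_or_size_splits=4) on the channel axis.
    input_gate, next_input, forget_gate, output_gate = (
        pre_activations[i * channels:(i + 1) * channels] for i in range(4)
    )
    # Forget part of the old cell state, then add the gated candidate values.
    next_cell = [
        sigmoid(f + forget_bias) * c + sigmoid(i) * math.tanh(g)
        for f, c, i, g in zip(forget_gate, cell, input_gate, next_input)
    ]
    # The output is the squashed new cell state scaled by the output gate.
    output = [
        math.tanh(c) * sigmoid(o) for c, o in zip(next_cell, output_gate)
    ]
    return output, next_cell
```

For example, lstm_gate_update([0.0] * 4, [1.0], forget_bias=0.0) makes every gate sigmoid(0) = 0.5, so the new cell value is 0.5 * 1.0 + 0.5 * tanh(0) = 0.5 and the output is 0.5 * tanh(0.5).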