basic_rnn_cells.py source code

python

Project: skiprnn-2017-telecombcn  Author: imatge-upc
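```python
# Method of a GRU cell class from basic_rnn_cells.py (class definition not shown).
# Written against the TensorFlow 1.x API (tf.variable_scope); `rnn_ops` is a helper
# module from the same project that provides `linear` and `layer_norm`.
import tensorflow as tf


def __call__(self, inputs, state, scope=None):
    """Gated recurrent unit (GRU) with num_units cells."""
    with tf.variable_scope(scope or type(self).__name__):
        with tf.variable_scope("gates"):  # Reset gate and update gate.
            # Start the gate biases at 1.0 so that, initially, the cell tends
            # not to reset and not to update (sigmoid(1.0) is about 0.73).
            concat = rnn_ops.linear([inputs, state], 2 * self._num_units, True, bias_start=1.0)
            # Split the 2 * num_units projection into reset (r) and update (u) parts.
            r, u = tf.split(value=concat, num_or_size_splits=2, axis=1)

            if self._layer_norm:
                r = rnn_ops.layer_norm(r, name="r")
                u = rnn_ops.layer_norm(u, name="u")

            # Apply the non-linearity after layer normalization.
            r = tf.sigmoid(r)
            u = tf.sigmoid(u)

        with tf.variable_scope("candidate"):
            # Candidate state from the input and the reset-gated previous state.
            c = self._activation(rnn_ops.linear([inputs, r * state], self._num_units, True))
        # Interpolate: u close to 1 keeps the previous state, u close to 0 takes the candidate.
        new_h = u * state + (1 - u) * c
    # The cell's output and its new state are the same tensor.
    return new_h, new_h
```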
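To make the data flow easier to follow, here is a minimal, framework-free sketch of one step of the same GRU computation in plain Python. It is an illustration under stated assumptions, not the project's code: the helpers `linear`, `layer_norm`, `gru_step` and `init_params` are hypothetical stand-ins for `rnn_ops.linear` and `rnn_ops.layer_norm`, the candidate activation is assumed to be tanh, the candidate bias is assumed to start at 0.0, and the layer normalization here has no learned gain or shift.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape [out_dim][in_dim]


def linear(parts: List[Vector], weight: Matrix, bias: Vector) -> Vector:
    """Dense layer on the concatenated inputs: weight @ concat(parts) + bias."""
    x = [v for part in parts for v in part]
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Plain layer normalization (zero mean, unit variance); no learned gain/shift (assumption)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gru_step(x, h, w_gates, b_gates, w_cand, b_cand, use_layer_norm=False):
    """One GRU step following the structure of the __call__ method (hypothetical helper)."""
    n = len(h)
    # Gates: one projection of [x, h] to 2n units, split into reset (r) then update (u).
    gates = linear([x, h], w_gates, b_gates)
    r, u = gates[:n], gates[n:]
    if use_layer_norm:
        # Normalize each gate's pre-activations separately, before the sigmoid.
        r, u = layer_norm(r), layer_norm(u)
    r = [sigmoid(v) for v in r]
    u = [sigmoid(v) for v in u]
    # Candidate: tanh (assumed activation) of a projection of [x, r * h].
    gated_h = [ri * hi for ri, hi in zip(r, h)]
    c = [math.tanh(v) for v in linear([x, gated_h], w_cand, b_cand)]
    # Interpolate: u close to 1 keeps the previous state, u close to 0 takes c.
    return [ui * hi + (1.0 - ui) * ci for ui, hi, ci in zip(u, h, c)]


def init_params(input_dim: int, num_units: int, gate_bias: float = 1.0, seed: int = 0):
    """Small random weights; gate biases start at gate_bias, candidate bias at 0.0 (assumed)."""
    rng = random.Random(seed)
    d = input_dim + num_units

    def mat(rows: int) -> Matrix:
        return [[rng.uniform(-0.1, 0.1) for _ in range(d)] for _ in range(rows)]

    return mat(2 * num_units), [gate_bias] * (2 * num_units), mat(num_units), [0.0] * num_units


# Usage: run two steps of a 4-unit cell on 3-dimensional inputs.
w_g, b_g, w_c, b_c = init_params(input_dim=3, num_units=4)
h = [0.0] * 4
for x in ([1.0, 0.0, -1.0], [0.5, 0.5, 0.5]):
    h = gru_step(x, h, w_g, b_g, w_c, b_c)
print(h)
```

Two properties are easy to check with this sketch. A large positive update-gate bias drives u toward 1, so the new state is essentially the previous one; the 1.0 bias in the original code pushes in that direction more mildly. And when the gates are layer-normalized, a bias that is identical for every unit (such as a uniform 1.0) is removed by the mean subtraction in this sketch; whether the same happens with the project's `rnn_ops.layer_norm` depends on details of that function that are not shown here.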