rnn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目：seq2seq 作者：eske 项目源码 文件源码
def __init__(self, num_units, activation=None, reuse=None, kernel_initializer=None, bias_initializer=None,
                 layer_norm=False, state_keep_prob=None, input_keep_prob=None, input_size=None, final=False):
        """Configure a GRU cell and pre-build uniform noise tensors for input/state dropout.

        Args:
            num_units: Number of hidden units. Also used as the structure passed to
                ``_enumerated_map_structure`` when building the state noise.
            activation: Activation callable. ``tf.nn.tanh`` is used when this is falsy.
            reuse: Forwarded to the base-class constructor as ``_reuse``.
            kernel_initializer: Stored on the instance; not used in this method.
            bias_initializer: Stored on the instance; not used in this method.
            layer_norm: Stored flag; not used in this method.
            state_keep_prob: When not None, a uniform noise tensor is built for the
                state and stored as ``self._state_noise``.
            input_keep_prob: When not None, a uniform noise tensor is built for the
                input from ``input_size`` and stored as ``self._input_noise``.
            input_size: Shape (or nested structure of shapes) describing the input.
                Presumably it must be set whenever ``input_keep_prob`` is; TODO confirm.
            final: Stored flag; its meaning is not visible in this method.

        Note:
            ``self._input_noise`` / ``self._state_noise`` are only assigned when the
            corresponding keep probability is not None.
        """
        # The base-class constructor must run before any attributes are assigned.
        super(DropoutGRUCell, self).__init__(_reuse=reuse)

        self._num_units = num_units
        self._activation = activation if activation else tf.nn.tanh
        self._kernel_initializer = kernel_initializer
        self._bias_initializer = bias_initializer
        self._layer_norm = layer_norm
        self._state_keep_prob = state_keep_prob
        self._input_keep_prob = input_keep_prob
        self._final = final

        def _uniform_with_unit_batch(shape):
            # Prepend a leading dimension of size 1. Presumably this lets one noise
            # sample be shared across the batch via broadcasting; confirm in call().
            full_shape = tf.concat(([1], tf.TensorShape(shape).as_list()), 0)
            return tf.random_uniform(full_shape)

        def _noise_for_leaf(i, s):
            # The leaf index supplied by _enumerated_map_structure is ignored.
            return _uniform_with_unit_batch(s)

        if input_keep_prob is not None:
            self._input_noise = DropoutGRUCell._enumerated_map_structure(_noise_for_leaf, input_size)
        if state_keep_prob is not None:
            self._state_noise = DropoutGRUCell._enumerated_map_structure(_noise_for_leaf, num_units)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号