models.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:punctuator2 作者: ottokart 项目源码 文件源码
def __init__(self, rng, n_in, n_out, minibatch_size):
        """Create the shared parameters and initial state of a GRU layer.

        Notation from: An Empirical Exploration of Recurrent Network Architectures

        Args:
            rng: random source forwarded to ``weights_Glorot`` for weight
                initialization (presumably a ``numpy.random.RandomState`` --
                TODO confirm against ``weights_Glorot``).
            n_in: input dimensionality; passed as the first dimension to the
                input-side weight initializers (``W_x``, ``W_x_h``).
            n_out: number of hidden units; sets the width of every parameter
                and of the hidden state.
            minibatch_size: number of rows of the initial hidden state ``h0``,
                which has shape ``(minibatch_size, n_out)``.
        """
        super(GRULayer, self).__init__()

        self.n_in = n_in
        self.n_out = n_out

        # Initial hidden state: all zeros, one row per minibatch element.
        # Allocate directly in floatX rather than building a float64 array and
        # copying it with astype().
        self.h0 = theano.shared(
            value=np.zeros((minibatch_size, n_out), dtype=theano.config.floatX),
            name='h0',
            borrow=True,
        )

        # Gate parameters: two gates packed side by side (2 * n_out columns).
        # NOTE(review): presumably the reset and update gates -- confirm the
        # split order in the step function that slices these.
        self.W_x = weights_Glorot(n_in, n_out * 2, 'W_x', rng)
        self.W_h = weights_Glorot(n_out, n_out * 2, 'W_h', rng)
        self.b = weights_const(1, n_out * 2, 'b', 0)

        # Input parameters (presumably used for the candidate hidden state --
        # TODO confirm in the step function).
        self.W_x_h = weights_Glorot(n_in, n_out, 'W_x_h', rng)
        self.W_h_h = weights_Glorot(n_out, n_out, 'W_h_h', rng)
        self.b_h = weights_const(1, n_out, 'b_h', 0)

        # Trainable parameters. Keep this order stable in case anything
        # saves or loads params positionally (TODO confirm).
        self.params = [self.W_x, self.W_h, self.b, self.W_x_h, self.W_h_h, self.b_h]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号