models.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:punctuator2 作者: ottokart 项目源码 文件源码
def __init__(self, rng, x, minibatch_size, n_hidden, x_vocabulary, y_vocabulary, stage1_model_file_name, p=None):
        """Build the stage-2 output model on top of a saved stage-1 model.

        The stage-1 model is loaded from ``stage1_model_file_name``. Its
        ``last_hidden_states`` are scanned in lockstep with ``p``: at each step
        the hidden state is concatenated with the matching slice of ``p`` as
        one extra input column. The result is fed through a GRU and projected
        to a softmax over ``y_vocabulary``. The per-step softmax outputs are
        stored in ``self.y``.

        Args:
            rng: random source forwarded to ``GRULayer``.
            x: symbolic input forwarded to ``load`` for the stage-1 model.
            minibatch_size: forwarded to ``load`` and ``GRULayer``.
            n_hidden: number of hidden units of the stage-2 GRU.
            x_vocabulary: input vocabulary; only stored on the instance here.
            y_vocabulary: output vocabulary; its length sets the output size.
            stage1_model_file_name: path of the saved stage-1 model.
            p: symbolic sequence scanned alongside the stage-1 hidden states.
                Each per-step slice is dimshuffled to a column, so it is
                presumably 1-D with one value per minibatch row (TODO confirm).
                The ``None`` default is kept only for signature compatibility;
                the model cannot be built without ``p``.

        Raises:
            ValueError: if ``p`` is None.
        """
        # Fail early with a clear message: theano.scan below cannot iterate over None.
        if p is None:
            raise ValueError("p must not be None: it is scanned alongside the stage-1 hidden states")

        y_vocabulary_size = len(y_vocabulary)

        self.stage1_model_file_name = stage1_model_file_name
        # The second value returned by ``load`` is not needed here.
        self.stage1, _ = load(stage1_model_file_name, minibatch_size, x)

        self.n_hidden = n_hidden
        self.x_vocabulary = x_vocabulary
        self.y_vocabulary = y_vocabulary

        # Output model. The GRU input is the stage-1 hidden state plus one
        # extra column taken from ``p`` (hence ``+ 1``).
        self.GRU = GRULayer(rng=rng, n_in=self.stage1.n_hidden + 1, n_out=n_hidden, minibatch_size=minibatch_size)
        self.Wy = weights_const(n_hidden, y_vocabulary_size, 'Wy', 0)
        self.by = weights_const(1, y_vocabulary_size, 'by', 0)

        # Only stage-2 parameters are collected. Stage-1 parameters are counted
        # in the report below but not included here (presumably kept fixed
        # during stage-2 training -- confirm with the trainer).
        self.params = [self.Wy, self.by]
        self.params += self.GRU.params

        def recurrence(x_t, p_t, h_tm1, Wy, by):
            # dimshuffle((0, 'x')) turns p_t into a column so it can be
            # appended to x_t along axis 1.
            h_t = self.GRU.step(x_t=T.concatenate((x_t, p_t.dimshuffle((0, 'x'))), axis=1), h_tm1=h_tm1)

            z = T.dot(h_t, Wy) + by
            y_t = T.nnet.softmax(z)

            return [h_t, y_t]

        # The hidden state is carried from GRU.h0. ``None`` marks y_t as an
        # output that is not fed back into the recurrence.
        [_, self.y], _ = theano.scan(fn=recurrence,
            sequences=[self.stage1.last_hidden_states, p],
            non_sequences=[self.Wy, self.by],
            outputs_info=[self.GRU.h0, None])

        # The loop variable is ``param`` (not ``p``) so it does not shadow the
        # ``p`` argument. The single-argument print(...) form behaves the same
        # on Python 2 and Python 3.
        n_params = sum(np.prod(param.shape.eval()) for param in self.params)
        print("Number of parameters is %d" % n_params)
        n_stage1_params = sum(np.prod(param.shape.eval()) for param in self.stage1.params)
        print("Number of parameters with stage1 params is %d" % (n_params + n_stage1_params))

        # Regularisation terms over the stage-2 parameters only.
        self.L1 = sum(abs(param).sum() for param in self.params)
        self.L2_sqr = sum((param**2).sum() for param in self.params)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号