rnn.py 文件源码

python
阅读 55 收藏 0 点赞 0 评论 0

项目:rnn-theano 作者: wangxggc 项目源码 文件源码
def __theano_build_train__(self):
        """Build the symbolic training graph and compile its Theano functions.

        Two recurrent passes share the same cell (``rnn_cell``): the first reads
        ``inputs[0]`` starting from a zero state, the second reads ``inputs[1]``
        starting from the first pass's last hidden state. The second pass's
        hidden states are projected through ``DecodeW``/``DecodeB`` and scored
        against ``inputs[2]`` with categorical cross-entropy weighted by
        ``masks[2]``.

        Compiled functions assigned on ``self``:
          * ``bptt(inputs, masks)``    -> the gradients returned by ``rms_prop``
          * ``loss(inputs, masks)``    -> the summed, mask-weighted cross-entropy
          * ``weights(inputs, masks)`` -> softmax outputs, reshaped to
            (axis0 * axis1, n_out)
          * ``sgd_step(inputs, masks, learning_rate, decay)`` -> applies the
            updates returned by ``rms_prop``; returns nothing
        """
        params = self.params
        params_names = self.param_names
        hidden_dim = self.hidden_dim

        # inputs[0]: token ids of the first sentence.
        # inputs[1]: token ids of the second sentence.
        # inputs[2]: target token ids scored by the output layer.
        # scan iterates over axis 0 of each inputs[k], and the zero state's
        # first axis must match axis 1, so axis 0 is presumably time and
        # axis 1 the batch -- TODO confirm against the data loader.
        inputs = T.itensor3("inputs")
        # masks[k] holds one weight per token of inputs[k] (presumably 1 for
        # real tokens and 0 for padding -- verify against the caller).
        masks = T.ftensor3("masks")

        def rnn_cell(x, mx, ph, Wh):
            # One recurrence step. Where the mask is 0 the previous state ph
            # is carried through unchanged, so those steps do not alter h.
            h = T.tanh(ph.dot(Wh) + x)
            h = mx[:, None] * h + (1 - mx[:, None]) * ph
            return [h]  # one row per batch element, hidden_dim columns

        def run_rnn(tokens, mask, w_in, b_in, w_rec, h0):
            # Embed `tokens`, apply the input projection for all steps at once,
            # then scan rnn_cell over axis 0 starting from state h0.
            # NOTE(review): the reshape assumes embedding width == hidden_dim.
            embedded = params["E"][tokens.flatten(), :].reshape(
                [tokens.shape[0], tokens.shape[1], hidden_dim])
            projected = embedded.dot(w_in) + b_in
            # rnn_cell creates no random streams or shared-variable updates,
            # so the updates dictionary scan returns is not needed here.
            [states], _scan_updates = theano.scan(
                fn=rnn_cell,
                sequences=[projected, mask],
                truncate_gradient=self.truncate,
                outputs_info=[dict(initial=h0)],
                non_sequences=[w_rec])
            return states

        # Size the zero state from the symbolic batch axis rather than the fixed
        # self.batch_size. Any batch length is then accepted. A constant size
        # of 1 would also make T.zeros broadcastable on that axis, which scan
        # rejects because the step output is not broadcastable.
        h0 = T.zeros([inputs[0].shape[1], hidden_dim])

        # First sentence, starting from the zero state.
        h1 = run_rnn(inputs[0], masks[0], params["W"][0], params["B"][0],
                     params["W"][1], h0)

        # Second sentence, starting from the first sentence's final state.
        h2 = run_rnn(inputs[1], masks[1], params["W"][2], params["B"][1],
                     params["W"][3], h1[-1])

        # Loss: project every hidden state to output scores, merge the first
        # two axes into one, softmax per row, and sum the per-token
        # cross-entropy weighted by masks[2].
        _s = h2.dot(params["DecodeW"]) + params["DecodeB"]
        _s = _s.reshape([_s.shape[0] * _s.shape[1], _s.shape[2]])
        _s = T.nnet.softmax(_s)
        _cost = T.nnet.categorical_crossentropy(_s, inputs[2].flatten())
        _cost = T.sum(_cost * masks[2].flatten())

        # Optimizer hyper-parameters, supplied on every sgd_step call.
        learning_rate = T.scalar("learning_rate")
        decay = T.scalar("decay")

        _grads, _updates = rms_prop(_cost, params_names, params, learning_rate, decay)

        # Compile the public entry points.
        self.bptt = theano.function([inputs, masks], _grads)
        self.loss = theano.function([inputs, masks], _cost)
        self.weights = theano.function([inputs, masks], _s)
        self.sgd_step = theano.function(
            [inputs, masks, learning_rate, decay],
            updates=_updates)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号