cnn_rnn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:transfer 作者: kimiyoung 项目源码 文件源码
def get_output_for(self, input, **kwargs):
        """Decode the highest-scoring label sequence for each instance.

        Runs a max-sum (Viterbi-style) dynamic program over the time axis,
        recording back pointers, then backtracks to recover one class label
        per position.

        Per-position class scores are ``T.dot(input, self.W)``, indexed below
        as inst * time * class. The score for moving from class j at the
        previous step to class k at the current step is ``self.W_sim[j, k]``.
        When the module-level flag ``CRF_INIT`` is set, ``self.W_init[0]`` is
        added to the time-0 scores (presumably start-transition scores --
        confirm against the layer's parameter setup).

        ``self.mask_input`` is read as inst * time. Where it is 0, scores,
        back pointers and decoded labels are carried over from the previous
        step instead of being updated. The time-0 mask column is never read
        (both scans use ``[1:]``). NOTE(review): the blend formulas assume the
        mask holds only 0/1 values -- confirm.

        Args:
            input: symbolic tensor; assumed inst * time * feature so that
                ``T.dot(input, self.W)`` is inst * time * class -- TODO
                confirm against the incoming layer.
            **kwargs: not used by this method.

        Returns:
            int32 symbolic tensor of shape inst * time holding the decoded
            class index at each position.

        NOTE(review): with a time length of 1 the forward scan runs zero
        steps, so ``scores[-1]`` has nothing to index and this will likely
        fail; confirm callers always supply at least two time steps.
        """
        def max_fn(f, mask, prev_score, prev_back, W_sim):
            """One forward max-sum step (scan body).

            f: class scores at this step (inst * class); mask: this step's
            mask (inst); prev_score / prev_back: running best scores and back
            pointers (inst * class); W_sim: transition scores
            (previous class * current class).

            Returns [next_score, next_back], where next_back[i, k] is the best
            previous class for instance i ending in class k at this step.
            """
            # Broadcast to inst * prev_class * cur_class:
            # next_score[i, j, k] = prev_score[i, j] + f[i, k] + W_sim[j, k]
            next_score = prev_score.dimshuffle(0, 1, 'x') + f.dimshuffle(0, 'x', 1) + W_sim.dimshuffle('x', 0, 1)
            # Best previous class for each current class (computed before
            # next_score is reduced and overwritten on the next line).
            next_back = T.argmax(next_score, axis = 1)
            next_score = T.max(next_score, axis = 1)
            mask = mask.dimshuffle(0, 'x')
            # Where mask is 0, keep the previous scores / back pointers unchanged.
            next_score = next_score * mask + prev_score * (1.0 - mask)
            next_back = next_back * mask + prev_back * (1.0 - mask)
            # The blend with the float (1.0 - mask) leaves a non-integer dtype;
            # cast back so the output matches init_back's int32 type.
            next_back = T.cast(next_back, 'int32')
            return [next_score, next_back]

        def produce_fn(back, mask, prev_py):
            """One backtracking step (scan body, iterated in reverse time).

            Given the label already chosen at step t (prev_py) and step t's
            back pointers, returns the label at step t - 1. Where mask is 0,
            the label is carried over unchanged.
            """
            # back: inst * class, prev_py: inst, mask: inst
            # Pick back[i, prev_py[i]] for every instance i.
            next_py = back[T.arange(prev_py.shape[0]), prev_py]
            next_py = mask * next_py + (1.0 - mask) * prev_py
            # The float blend above changes the dtype; restore int32 labels.
            next_py = T.cast(next_py, 'int32')
            return next_py

        # Per-position class scores: inst * time * class.
        f = T.dot(input, self.W)

        # Time 0 initializes the DP: its raw scores, plus zero int32 back
        # pointers (inst * class) as the scan's initial state.
        init_score, init_back = f[:, 0, :], T.zeros_like(f[:, 0, :], dtype = 'int32')
        if CRF_INIT:
            init_score = init_score + self.W_init[0].dimshuffle('x', 0)
        # Forward pass over time steps 1..time-1 (sequences are made
        # time-major). scores / backs: (time - 1) * inst * class.
        ([scores, backs], _) = theano.scan(fn = max_fn, \
            sequences = [f.dimshuffle(1, 0, 2)[1: ], self.mask_input.dimshuffle(1, 0)[1: ]], \
            outputs_info = [init_score, init_back], non_sequences = [self.W_sim], strict = True)

        # Best class at the final step. Because masked steps carry the score
        # forward, scores[-1] holds each instance's score from its last
        # unmasked step.
        init_py = T.argmax(scores[-1], axis = 1)
        init_py = T.cast(init_py, 'int32')
        # init_py: inst, backs: time * inst * class
        # Backtrack from the last step towards step 1. Each step emits the
        # label one position earlier. If padding is a trailing run of mask 0
        # (assumption -- confirm), padded positions keep the last valid label.
        pys, _ = theano.scan(fn = produce_fn, \
            sequences = [backs, self.mask_input.dimshuffle(1, 0)[1:]], outputs_info = [init_py], go_backwards = True)
        # pys: (rev_time - 1) * inst
        pys = pys.dimshuffle(1, 0)[:, :: -1]
        # pys : inst * (time - 1)
        # Append the final-step label: the result is inst * time, int32.
        return T.concatenate([pys, init_py.dimshuffle(0, 'x')], axis = 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号