cnn_rnn.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:transfer 作者: kimiyoung 项目源码 文件源码
def get_output_for(self, input, **kwargs):
        """Return the batch-averaged negative log-likelihood of the gold label paths.

        The computation follows the usual linear-chain CRF recipe:

        * ``f = input . W`` gives per-step class scores (inst * time * class),
        * ``norm`` is the log-partition term, accumulated over time with a
          forward recursion inside ``theano.scan``,
        * ``f_pot`` is the sum of the scores of the gold labels on unmasked steps,
        * ``g_pot`` is the sum of ``W_sim`` transition scores between
          consecutive gold labels where both steps are unmasked.

        Besides ``input`` the method reads ``self.W``, ``self.W_sim``,
        ``self.W_init``, ``self.label_input`` (inst * time), ``self.mask_input``
        (inst * time) and ``self.num_classes``, plus the module-level switches
        ``CRF_INIT``, ``COST`` and ``COST_CONST``.

        Args:
            input: symbolic tensor that is projected by ``self.W``; the result
                is indexed as inst * time * class.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            Symbolic scalar ``-(f_pot + g_pot - norm) / f.shape[0]``.

        NOTE(review): ``outputs[-1]`` presumably needs at least two time steps,
        since the scan runs over steps ``1:`` only -- confirm against callers.
        """
        def norm_fn(f, mask, label, previous, W_sim):
            # f: inst * class, mask: inst, previous: inst * class, W_sim: class * class
            # scores[i, p, c] = previous[i, p] + f[i, c] + W_sim[p, c]
            scores = previous.dimshuffle(0, 1, 'x') + f.dimshuffle(0, 'x', 1) + W_sim.dimshuffle('x', 0, 1)
            if COST:
                # Add COST_CONST to every class except the gold one at this step.
                scores = scores + COST_CONST * (1.0 - T.extra_ops.to_one_hot(label, self.num_classes).dimshuffle(0, 'x', 1))
            # scores: inst * prev * cur
            reduced = theano_logsumexp(scores, axis = 1)
            # reduced: inst * class
            mask = mask.dimshuffle(0, 'x')
            # Masked steps carry the previous state forward unchanged.
            return previous * (1.0 - mask) + reduced * mask

        f = T.dot(input, self.W)
        # f: inst * time * class

        initial = f[:, 0, :]
        if CRF_INIT:
            initial = initial + self.W_init[0].dimshuffle('x', 0)
        if COST:
            initial = initial + COST_CONST * (1.0 - T.extra_ops.to_one_hot(self.label_input[:, 0], self.num_classes))
        # scan iterates over the leading axis, so move time to the front and
        # skip step 0, which is already folded into `initial`.
        outputs, _ = theano.scan(fn = norm_fn, \
         sequences = [f.dimshuffle(1, 0, 2)[1: ], self.mask_input.dimshuffle(1, 0)[1: ], self.label_input.dimshuffle(1, 0)[1:]], \
         outputs_info = initial, non_sequences = [self.W_sim], strict = True)
        norm = T.sum(theano_logsumexp(outputs[-1], axis = 1))

        # Pick the gold-label score at every (inst, time) position, zeroing masked steps.
        f_pot = (f.reshape((-1, f.shape[-1]))[T.arange(f.shape[0] * f.shape[1]), self.label_input.flatten()] * self.mask_input.flatten()).sum()
        if CRF_INIT:
            f_pot += self.W_init[0][self.label_input[:, 0]].sum()

        labels = self.label_input
        # labels: inst * time
        mask = self.mask_input
        # mask : inst * time

        # Pair step t with step t + 1 by slicing rather than T.roll: a circular
        # shift would also pair the last step with the first one and add a
        # transition to the gold score that the forward recursion never sees.
        prev_labels = labels[:, :-1]
        next_labels = labels[:, 1:]
        prev_mask = mask[:, :-1]
        next_mask = mask[:, 1:]

        g_pot = (self.W_sim[prev_labels.flatten(), next_labels.flatten()] * prev_mask.flatten() * next_mask.flatten()).sum()

        return - (f_pot + g_pot - norm) / f.shape[0]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号