covc_encdec.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:CopyNet 作者: MultiPath 项目源码 文件源码
def generate_(self, inputs, mode='display', return_attend=False, return_all=False):
        """Encode ``inputs`` and obtain an output sequence from the decoder.

        Sampling options are read from ``self.config``: ``sample_beam`` and
        ``max_len`` are always used, while ``sample_stoch`` and
        ``sample_argmax`` are only looked up when ``mode == 'display'``; any
        other mode passes ``None`` for both.

        Args:
            inputs: Source batch, passed to ``self.encoder.gtenc`` and again
                to ``self.decoder.get_sample``.
            mode: ``'display'`` to honour the configured stochastic/argmax
                flags; anything else disables both.
            return_attend: Forwarded unchanged to ``get_sample``.
            return_all: If true, return the decoder's raw
                ``(sample, score, ppp)`` triple without post-processing.

        Returns:
            ``(sample, np.exp(score), ppp)``. In the non-stochastic case the
            candidate with the smallest length-normalised score is selected;
            in the stochastic case the score is divided by ``len(sample)``.
        """
        use_config_flags = mode == 'display'
        args = {
            'k': self.config['sample_beam'],
            'maxlen': self.config['max_len'],
            'stochastic': self.config['sample_stoch'] if use_config_flags else None,
            'argmax': self.config['sample_argmax'] if use_config_flags else None,
            'return_attend': return_attend,
        }

        # The encoder returns six values; only the context and its mask are used.
        context, _, c_mask, _, _, _ = self.encoder.gtenc(inputs)

        if 'explicit_loc' in self.config and self.config['explicit_loc']:
            # Append a rectangular identity matrix along the last axis so each
            # position carries an explicit location indicator.
            n_steps = context.shape[1]
            loc_eye = np.eye(n_steps, self.config['encode_max_len'], dtype='float32')
            loc_feat = np.repeat(loc_eye[None, :, :], context.shape[0], axis=0)
            context = np.concatenate([context, loc_feat], axis=2)

        sample, score, ppp = self.decoder.get_sample(context, c_mask, inputs, **args)
        if return_all:
            return sample, score, ppp

        if args['stochastic']:
            # Single sampled sequence: normalise the score by its length.
            score /= float(len(sample))
            return sample, np.exp(score), ppp

        # Several candidates: normalise each score by candidate length and
        # keep the one with the lowest normalised score.
        lengths = np.array([len(candidate) for candidate in sample])
        normalised = score / lengths
        best = normalised.argmin()
        return sample[best], np.exp(normalised.min()), ppp[best]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号