covc_encdec.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:seq2seq-keyphrase 作者: memray 项目源码 文件源码
def generate_multiple(self, inputs, mode='display', return_attend=False, return_all=True, return_encoding=False):
    """Encode ``inputs`` and sample output sequences from the decoder.

    Sampler settings are read from ``self.config`` and forwarded to
    ``self.decoder.get_sample`` as keyword arguments: ``sample_beam`` (as
    ``k``), ``max_len`` (as ``maxlen``), ``predict_type`` (as ``type``) and,
    only in display mode, ``sample_stoch`` / ``sample_argmax``.

    Args:
        inputs: Passed unchanged to ``self.encoder.gtenc`` and then to
            ``self.decoder.get_sample``.
        mode: When equal to ``'display'`` the ``stochastic`` and ``argmax``
            sampler options are taken from the config; for any other value
            both are passed as ``None``.
        return_attend: Forwarded to ``get_sample`` as ``return_attend``.
        return_all: Must be truthy (together with ``return_encoding``) for
            the extended 4-tuple to be returned.
        return_encoding: See ``return_all``.

    Returns:
        ``(context, sample, score, output_encoding)`` when both
        ``return_all`` and ``return_encoding`` are truthy, otherwise
        ``(sample, score)``.

    Notes (carried over from the original author's comments, not verifiable
    from this method alone -- TODO confirm against ``gtenc``):
        * ``context`` is the per-time-step encoding, reportedly shaped
          ``[nb_sample, max_len, 2 * enc_hidden_dim]`` (forward and backward
          RNN states concatenated).
        * ``c_mask`` reportedly marks the non-padding positions of
          ``context`` and is shaped ``[nb_sample, max_len]``.
        * ``gtenc`` is described as similar to ``encoder.encode()`` but also
          returning gate values; those values are not used here.
    """
    display = mode == 'display'
    # ``type`` is the keyword name expected by get_sample, so it is kept
    # even though it shadows a builtin inside the dict call.
    args = dict(k=self.config['sample_beam'],
                maxlen=self.config['max_len'],
                stochastic=self.config['sample_stoch'] if display else None,
                argmax=self.config['sample_argmax'] if display else None,
                return_attend=return_attend,
                type=self.config['predict_type'])

    # gtenc yields six values; only the context (1st) and its mask (3rd)
    # are needed. The 2nd/4th (embedding / final encoding per the original
    # notes) and the 5th/6th (gate values) are discarded.
    context, _, c_mask, _, _, _ = self.encoder.gtenc(inputs)

    # Optional explicit position features (the original author's note
    # suggests this option is normally not set). Row t of the eye matrix has
    # a single 1 in column t, so each time step gets a positional indicator
    # of width ``encode_max_len`` appended along the last axis.
    if 'explicit_loc' in self.config and self.config['explicit_loc']:
        max_len = context.shape[1]
        exp_loc = np.eye(max_len, self.config['encode_max_len'], dtype='float32')[None, :, :]
        exp_loc = np.repeat(exp_loc, context.shape[0], axis=0)
        context = np.concatenate([context, exp_loc], axis=2)

    # get_sample returns four values; the third is not used by this method.
    sample, score, _, output_encoding = self.decoder.get_sample(context, c_mask, inputs, **args)

    if return_all and return_encoding:
        return context, sample, score, output_encoding
    return sample, score
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号