ntm_encdec.py source code

python

Project: CopyNet  Author: MultiPath
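import logging

import theano
import theano.tensor as T

# alloc_zeros_matrix is a helper from the CopyNet codebase (not shown here);
# it allocates a zero-filled matrix of the given shape.
logger = logging.getLogger(__name__)


def compile_inference(self):
    """
    Build and compile the hidden-action prediction (inference) function.
    """
    inputs         = T.imatrix()  # padded input word-index sequences, shape (batch_size, seq_len)

    # initial encoder context, chosen by mode
    if self.config['mode']   == 'RNN':
        # RNN: one zero context vector per sequence in the batch
        context    = alloc_zeros_matrix(inputs.shape[0], self.config['enc_contxt_dim'])
    elif self.config['mode'] == 'NTM':
        # NTM: tile the shared memory matrix once per sequence in the batch
        context    = T.repeat(self.memory[None, :, :], inputs.shape[0], axis=0)
    else:
        raise NotImplementedError("unknown mode: %s" % self.config['mode'])

    # encoding
    memorybook     = self.encoder.build_encoder(inputs, context)

    # posterior Q(a|y) = sigmoid(Posterior * encoded); p_dis is the prior over a
    q_dis          = self.Post(memorybook)
    p_dis          = self.Prior()

    # compile: inputs -> (encoded memory, posterior, prior)
    self.inference_ = theano.function([inputs], [memorybook, q_dis, p_dis])
    logger.info("inference function compile done.")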
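To make the shapes in the mode switch concrete, here is a minimal NumPy sketch. It assumes NumPy arrays stand in for the Theano symbolic tensors, and `build_context`, `bernoulli_posterior`, `w_post` and `b_post` are hypothetical names, not part of CopyNet. In RNN mode the context is one zero vector per sequence; in NTM mode the shared memory matrix is tiled along a new batch axis, which is what `T.repeat(self.memory[None, :, :], inputs.shape[0], axis=0)` does symbolically. The posterior part only illustrates a sigmoid over a linear map of the encoding; the real `self.Post` may be structured differently.

```python
import numpy as np


def build_context(mode, batch_size, enc_contxt_dim=None, memory=None):
    """Mirror the mode switch in compile_inference with concrete arrays.

    RNN: a zero vector of size enc_contxt_dim for each sequence in the batch.
    NTM: the shared (n_slots, dim) memory tiled to (batch_size, n_slots, dim).
    """
    if mode == 'RNN':
        # analogous to alloc_zeros_matrix(batch_size, enc_contxt_dim)
        return np.zeros((batch_size, enc_contxt_dim), dtype=np.float32)
    elif mode == 'NTM':
        # analogous to T.repeat(memory[None, :, :], batch_size, axis=0)
        return np.repeat(memory[None, :, :], batch_size, axis=0)
    raise NotImplementedError("unknown mode: %s" % mode)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bernoulli_posterior(encoded, w_post, b_post):
    """Hypothetical stand-in for self.Post: q(a|y) = sigmoid(encoded . W + b).

    Each output entry is the probability that one binary hidden action is on.
    """
    return sigmoid(encoded @ w_post + b_post)


# Usage: a batch of 3 sequences, NTM memory with 5 slots of width 4.
rng = np.random.default_rng(0)
memory = rng.standard_normal((5, 4)).astype(np.float32)

ctx = build_context('NTM', 3, memory=memory)
print(ctx.shape)  # (3, 5, 4)
print(build_context('RNN', 3, enc_contxt_dim=8).shape)  # (3, 8)

# hypothetical posterior weights mapping width-4 encodings to 2 hidden actions
q = bernoulli_posterior(ctx, np.ones((4, 2)), np.zeros(2))
print(q.shape)  # (3, 5, 2)
```

Tiling gives every sequence in the batch the same initial memory state to start encoding from, while the RNN branch needs only a flat zero vector per sequence.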