classifier.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:narrative-prediction 作者: roemmele 项目源码 文件源码
def predict(self, seq1, seq2, pred_method='multiply', unigram_probs=None):
    '''Score how likely seq2 is given seq1 under the model.

    Right now this function only handles getting the probability for a
    single (seq1, seq2) pair: each sequence is wrapped into a batch of one
    before being handed to the model.

    Args:
        seq1: the conditioning sequence. If the model takes flat, embedded
            input it is used as-is with a leading batch axis added
            (presumably already a numpy array -- TODO confirm); otherwise it
            is encoded with get_vector_batch / get_seq_batch.
        seq2: the sequence to score; encoded with get_vector_batch (flat
            output) or get_seq_batch (sequential output).
        pred_method: how the per-token probabilities selected for seq2 are
            reduced to a single score:
              'multiply' - sum of log-probabilities (log of their product),
              'mean'     - mean log-probability,
              'last'     - log-probability of the last selected token,
              'max'      - log of the largest selected probability.
        unigram_probs: optional array of unigram probabilities, used only
            when self.flat_output is set. The predicted distribution is
            divided by unigram_probs ** 0.66, which lowers the scores of
            words with high unigram probability.

    Returns:
        The score as a numpy float (a log-probability value).

    Raises:
        ValueError: if pred_method is not one of the supported names.
    '''
    valid_methods = ('multiply', 'mean', 'last', 'max')
    # Fail fast, before running the model, instead of hitting an unbound
    # local at the end of the function.
    if pred_method not in valid_methods:
        raise ValueError('unknown pred_method %r; expected one of: %s'
                         % (pred_method, ', '.join(valid_methods)))

    if self.flat_input:
        if self.embedded_input:
            seq1 = seq1[None]  # add a batch axis of size 1
        else:
            seq1 = get_vector_batch([seq1], vector_length=self.lexicon_size + 1)
    else:
        seq1 = get_seq_batch([seq1], max_length=self.n_timesteps)

    # Model output for the single item in the batch.
    probs = self.model.predict_on_batch(seq1)[0]

    if self.flat_output:
        if unigram_probs is not None:
            probs = probs / unigram_probs ** 0.66
            # Division by a zero unigram probability gives inf; zero those out.
            probs[numpy.isinf(probs)] = 0.0
        seq2 = get_vector_batch([seq2], vector_length=self.lexicon_size + 1)
        # Keep only the probabilities at positions that are non-zero in
        # seq2's encoded vector.
        probs = probs[seq2[0].astype('bool')]
    else:
        seq2 = get_seq_batch([seq2], padding='post', max_length=self.n_timesteps)
        # For each timestep, pick the probability assigned to seq2's token...
        probs = probs[numpy.arange(self.n_timesteps), seq2]
        # ...then drop positions whose index is 0 (presumably post-padding).
        probs = probs[seq2 > 0]

    if pred_method == 'multiply':
        # Summing logs equals the log of the product but avoids underflow.
        prob = numpy.sum(numpy.log(probs))
    elif pred_method == 'mean':
        prob = numpy.mean(numpy.log(probs))
    elif pred_method == 'last':
        prob = numpy.log(probs[-1])
    else:  # 'max' -- the only remaining valid method
        prob = numpy.log(numpy.max(probs))
    return prob
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号