nvdm_nobatch_new.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:NVDM-For-Document-Classification 作者: cryanzpj 项目源码 文件源码
def prediction(x_sample, y_sample): # sample has size 20
            '''
            Get the perplexity of the test set
            '''
            perplist = []
            for i in range(20):
                x_batch_id = [ _ for _ in itertools.compress(range(10000), map(lambda x: x>0,x_sample[0]))]
                feed_dict = {nvdm.input_x: x_sample[i].reshape(1,10000),
                             nvdm.input_y: y_sample[i].reshape(1,103)}
                step, p_xi_h = sess.run([nvdm.global_step, nvdm.p_xi_h], feed_dict)

                valid_p = np.mean(np.log(p_xi_h[x_batch_id]))
                perplist.append(valid_p)
            print("perplexity: {}".format(np.exp(-np.mean(perplist))))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号