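import numpy as np

# nvdm (the model) and sess (a TensorFlow session) are expected to exist at module level.
def prediction(x_sample, y_sample):  # x_sample holds 20 documents; y_sample is unused
    '''
    Print the perplexity of the test sample.
    '''
    perplist = []
    for i in range(20):
        # Vocabulary indices of the words that occur in document i.
        x_batch_id = np.flatnonzero(x_sample[i] > 0)
        feed_dict = {nvdm.input_x: x_sample[i].reshape(1, 10000)}
        step, p_xi_h = sess.run([nvdm.global_step, nvdm.p_xi_h], feed_dict)
        # Mean log-probability of the words present in this document.
        valid_p = np.mean(np.log(p_xi_h[x_batch_id]))
        perplist.append(valid_p)
    # Perplexity is the exponential of the negative mean log-probability.
    print("perplexity: {}".format(np.exp(-np.mean(perplist))))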
Source file: nvdm_nobatch.py
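The perplexity arithmetic in `prediction` can be checked without a TensorFlow session. The following is a minimal sketch, not the original author's code: the function name `mean_log_prob_perplexity` and the `word_probs` array (standing in for the model output `nvdm.p_xi_h`) are hypothetical, and it assumes each document's probabilities arrive as a row of a 2-D array.

```python
import numpy as np

def mean_log_prob_perplexity(x_docs, word_probs):
    """Perplexity from per-document mean log-probabilities.

    x_docs:     (n_docs, vocab) bag-of-words counts.
    word_probs: (n_docs, vocab) word probabilities; a hypothetical
                stand-in for what the model would return per document.
    """
    per_doc = []
    for counts, probs in zip(x_docs, word_probs):
        # Indices of the words that actually occur in this document.
        present = np.flatnonzero(counts > 0)
        # Mean log-probability over those words only.
        per_doc.append(np.mean(np.log(probs[present])))
    # Exponential of the negative mean over documents.
    return float(np.exp(-np.mean(per_doc)))

# Toy data: two documents over a four-word vocabulary.
x_docs = np.array([[1, 0, 2, 0],
                   [0, 3, 0, 0]])
uniform = np.full((2, 4), 0.25)
ppl = mean_log_prob_perplexity(x_docs, uniform)
print("perplexity: {}".format(ppl))
```

With uniform probabilities the result should be approximately equal to the vocabulary size (4 here), which makes a quick sanity check. Note that this average weights every present word and every document equally; it does not weight by word counts.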