def compute_dt_dist(docs, labels, tags, model, max_len, batch_size, pad_id, idxvocab, output_file):
    """Run the model over all documents and save their attention vectors.

    Documents are processed in batches of ``batch_size``. For each batch the
    model's ``attention`` output is evaluated and its first ``s`` rows are
    collected, where ``s`` is the last value returned by ``get_batch_doc``
    (presumably the number of real, non-padding documents in the batch --
    TODO confirm against ``get_batch_doc``). The collected rows are written
    to ``output_file`` with ``np.save``.

    When the module-level ``debug`` flag is truthy, each document's words and
    its attention values (sorted in descending order) are printed.

    Args:
        docs, labels, tags: passed through to ``get_batch_doc``.
        model: object exposing ``doc`` and ``tag`` feed placeholders and the
            ``attention`` and ``mean_topic`` outputs.
        max_len: passed through to ``get_batch_doc``.
        batch_size: number of documents per batch; must be non-zero.
        pad_id: id skipped when printing a document's words in debug mode.
        idxvocab: maps a word id to its string form (used for debug output).
        output_file: path the collected attention rows are saved to.
    """
    # Generate batches.
    num_batches = int(math.ceil(float(len(docs)) / batch_size))
    dt_dist = []
    docid = 0  # running document counter across batches (debug output only)

    for i in range(num_batches):
        x, _, _, t, s = get_batch_doc(docs, labels, tags, i, max_len, cf.tag_len, batch_size, pad_id)
        # mean_topic is fetched alongside attention, as before, but is unused.
        attention, _ = sess.run([model.attention, model.mean_topic], {model.doc: x, model.tag: t})
        dt_dist.extend(attention[:s])

        if debug:
            for si in range(s):
                words = " ".join([idxvocab[item] for item in x[si] if item != pad_id])
                # Single formatted strings give the same output whether print
                # is a statement or a function.
                print("\n\nDoc %s = %s" % (docid, words))
                for ti in matutils.argsort(attention[si], reverse=True):
                    print("Topic %s = %s" % (ti, attention[si][ti]))
                docid += 1

    # np.save writes binary data: open in binary mode and make sure the
    # handle is closed (and flushed) even if saving raises.
    with open(output_file, "wb") as out:
        np.save(out, dt_dist)
tdlm_test.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录