tdlm_test.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:topically-driven-language-model 作者: jhlau 项目源码 文件源码
def gen_sent_on_topic(idxvocab, vocabxid, start_symbol, end_symbol, cf):
    """Write topic-conditioned generated sentences to ``args.gen_sent_on_topic``.

    For every topic index in ``range(cf.topic_number)`` the output file gets
    a header, the topic's top words, one sentence generated with a
    temperature argument of 0 (labelled "greedy; argmax" in the output), and
    ``gen_num`` sentences for each temperature in ``gen_temps``.

    Args:
        idxvocab: Indexable collection mapping a word id to its word string;
            its length is passed to the model as the vocabulary size.
        vocabxid: Mapping from a word string to its word id.
        start_symbol: Word whose id is passed as the first token for
            generation.
        end_symbol: Word whose id is passed as the stop token for generation.
        cf: Configuration object; ``topic_number`` and ``lm_sent_len`` are
            read here and the object itself is handed to ``LM``.

    Note:
        Relies on names defined outside this function: ``args``, ``tm``,
        ``sess``, ``topn``, ``initializer``, ``gen_temps``, ``gen_num``,
        ``LM``, ``tf`` and ``codecs``.
    """
    # The file is opened first (as before), but now inside a ``with`` block so
    # it is flushed and closed even if topic extraction or generation raises.
    with codecs.open(args.gen_sent_on_topic, "w", "utf-8") as output:
        # get_topics returns a pair; only the topic word ids are used here.
        topics, _ = tm.get_topics(sess, topn=topn)

        # Single-step, batch-size-1 model built under the existing "model"
        # variable scope with reuse=True.
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mgen = LM(is_training=False, vocab_size=len(idxvocab), batch_size=1,
                      num_steps=1, config=cf, reuse_conv_variables=True)

        def words(ids):
            # Turn a sequence of word ids into a space-separated string.
            return " ".join([idxvocab[item] for item in ids])

        def generate(topic, temperature):
            # Generation is capped at 10 tokens beyond the configured
            # sentence length.
            return mgen.generate_on_topic(
                sess, topic, vocabxid[start_symbol], temperature,
                cf.lm_sent_len + 10, vocabxid[end_symbol])

        for t in range(cf.topic_number):
            output.write("\n" + "=" * 100 + "\n")
            output.write("Topic " + str(t) + ":\n")
            output.write(words(topics[t]) + "\n\n")

            output.write("\nSentence generation (greedy; argmax):" + "\n")
            output.write("[0] " + words(generate(t, 0)) + "\n")

            for temp in gen_temps:
                output.write("\nSentence generation (random; temperature = "
                             + str(temp) + "):\n")
                # ``range`` (not ``xrange``) matches the topic loop above and
                # works on both Python 2 and Python 3.
                for i in range(gen_num):
                    output.write("[" + str(i) + "] "
                                 + words(generate(t, temp)) + "\n")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号