tdlm_model.py source file

python

Project: topically-driven-language-model    Author: jhlau
def generate(self, sess, conv_hidden, start_word_id, temperature, max_length, stop_word_id):
    # Start from the RNN cell's zero state for a batch of one sequence.
    state = sess.run(self.cell.zero_state(1, tf.float32))
    x = [[start_word_id]]   # model input: batch of 1, a single time step
    sent = [start_word_id]  # generated word ids, including the start word

    for _ in range(max_length):
        # Feed conv_hidden only when it is actually supplied as a NumPy array.
        feed = {self.x: x, self.initial_state: state}
        if isinstance(conv_hidden, np.ndarray):
            feed[self.conv_hidden] = conv_hidden
        probs, state = sess.run([self.probs, self.state], feed)
        # Sample the next word id from the output distribution at the given temperature.
        sent.append(self.sample(probs[0], temperature))
        if sent[-1] == stop_word_id:
            break
        # The sampled word becomes the next input; the returned state is carried over.
        x = [[sent[-1]]]

    return sent

# Generate a sequence of words, given a topic (header of the next method; its body is not included in this excerpt)
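The method delegates the actual draw to `self.sample(probs[0], temperature)`, whose body is not shown in this excerpt. The sketch below assumes it performs standard temperature sampling (divide log-probabilities by the temperature, renormalize, draw one index) and reproduces the same decoding loop without TensorFlow. The names `sample_with_temperature`, `generate_ids`, and the `step` callable are hypothetical; `step` stands in for the `sess.run([self.probs, self.state], ...)` call.

```python
import numpy as np


def sample_with_temperature(probs, temperature, rng=None):
    """Draw one index from `probs` after temperature scaling.

    Hypothetical stand-in for the model's `sample` method (not shown in the
    excerpt). temperature < 1 sharpens the distribution, > 1 flattens it.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Work in log space; the epsilon avoids log(0) for zero-probability words.
    logits = np.log(np.asarray(probs, dtype=np.float64) + 1e-20) / temperature
    logits -= logits.max()  # numerical stability before exponentiating
    scaled = np.exp(logits)
    scaled /= scaled.sum()  # renormalize to a valid distribution
    return int(rng.choice(len(scaled), p=scaled))


def generate_ids(step, start_word_id, stop_word_id, max_length, temperature, rng=None):
    """Autoregressive decoding loop mirroring `generate` above.

    `step(word_id, state) -> (probs, state)` is a hypothetical callable that
    replaces the TensorFlow session call.
    """
    state = None  # plays the role of the cell's zero state
    sent = [start_word_id]
    for _ in range(max_length):
        probs, state = step(sent[-1], state)  # previous word is the next input
        sent.append(sample_with_temperature(probs, temperature, rng))
        if sent[-1] == stop_word_id:
            break
    return sent
```

As in the original, the stop word is appended before the loop exits, and if it never appears the result holds at most `max_length + 1` ids (the start word plus `max_length` samples).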