def generate(self, sess, conv_hidden, start_word_id, temperature, max_length, stop_word_id):
    """Sample a sequence of word ids from the model, one step at a time.

    Decoding starts from the cell's zero state and ``start_word_id``. Each step
    does four things:
    - feeds the most recent word id (as ``[[word_id]]``) and the current
      recurrent state;
    - evaluates ``self.probs`` and ``self.state``;
    - draws the next word id with ``self.sample(probs[0], temperature)``;
    - carries the returned state into the next step.

    Generation ends after ``max_length`` sampled words, or as soon as
    ``stop_word_id`` is sampled, whichever comes first.

    Args:
        sess: TensorFlow session used to evaluate the graph.
        conv_hidden: Optional extra input fed to ``self.conv_hidden``. It is
            fed only when it is a NumPy array. Any other value (e.g. ``None``)
            runs the model without that feed.
        start_word_id: Word id that seeds decoding. It is also the first
            element of the returned list.
        temperature: Sampling temperature, passed through to ``self.sample``.
        max_length: Maximum number of words to sample after the start word.
        stop_word_id: Word id that terminates generation. If sampled, it is
            included as the last element of the result.

    Returns:
        A list of word ids: ``start_word_id`` followed by up to
        ``max_length`` sampled ids.
    """
    # Zero initial state for a single sequence (batch size 1).
    state = sess.run(self.cell.zero_state(1, tf.float32))
    sent = [start_word_id]

    # Whether to feed conv_hidden never changes during decoding, so decide once.
    # Use a type check rather than `conv_hidden != None`: on an ndarray that
    # comparison is elementwise, and its truth value is ambiguous.
    feed_conv_hidden = isinstance(conv_hidden, np.ndarray)

    # `range` (not the Python-2-only `xrange`) keeps this working on Python 3.
    for _ in range(max_length):
        feed = {self.x: [[sent[-1]]], self.initial_state: state}
        if feed_conv_hidden:
            feed[self.conv_hidden] = conv_hidden
        probs, state = sess.run([self.probs, self.state], feed)
        sent.append(self.sample(probs[0], temperature))
        if sent[-1] == stop_word_id:
            break
    return sent
# Generate a sequence of words, given a topic.
tdlm_model.py 文件源码
python
阅读 33
收藏 0
点赞 0
评论 0
评论列表
文章目录