def generate_text(session, model, config, starting_text='<eos>',
                  stop_length=100, stop_tokens=None, temp=1.0):
    """Generate text from the model, seeded with the token(s) in ``starting_text``.

    The seed is split on whitespace and encoded with ``model.vocab``. The
    model then predicts one token per step: each sampled token is fed back
    in as the next input while the recurrent state returned by the previous
    step is carried forward.

    :param session: session used to evaluate the model's tensors.
    :type session: tf Session
    :param model: model exposing ``vocab``, ``initial_state``, ``final_state``,
        ``predictions``, ``input_placeholder`` and ``dropout_placeholder``.
    :type model: RNNLanguageModel
    :param config: not used by this function; kept for interface
        compatibility.
    :type config: Config
    :param starting_text: whitespace-separated seed token(s); must contain
        at least one token.
    :type starting_text: str
    :param stop_length: maximum number of tokens to generate after the seed.
        The returned list therefore has at most
        ``len(starting_text.split()) + stop_length`` entries.
    :type stop_length: int
    :param stop_tokens: if given, generation ends right after any of these
        words is produced; the stop word itself is kept in the output.
    :type stop_tokens: None or list of str
    :param temp: sampling temperature forwarded to ``sample``.
    :type temp: float
    :return: the seed tokens followed by the generated tokens, decoded to
        words.
    :rtype: list of str
    :raises ValueError: if ``starting_text`` contains no tokens.
    """
    tokens = [model.vocab.encode(word) for word in starting_text.split()]
    if not tokens:
        # Every step feeds tokens[-1]; an empty seed would otherwise fail
        # with an opaque IndexError inside the loop.
        raise ValueError("starting_text must contain at least one token")

    # Build the lookup set once instead of scanning a list on every step.
    stop_set = frozenset(stop_tokens) if stop_tokens else frozenset()

    state = session.run(model.initial_state)
    for _ in range(stop_length):
        feed = {model.input_placeholder: [[tokens[-1]]],
                model.initial_state: state,
                # NOTE(review): presumably a keep-probability, so 1.0
                # disables dropout during generation — confirm.
                model.dropout_placeholder: 1.0}
        # Fetch the new recurrent state (fed back next step) and the
        # prediction for the last time step.
        state, y_pred = session.run([model.final_state,
                                     model.predictions[-1]],
                                    feed_dict=feed)
        # assumes y_pred is (1, vocab_size) for the single-token batch;
        # `sample` is a helper defined elsewhere in this module — TODO
        # confirm shape.
        next_word_idx = sample(y_pred[0], temperature=temp)
        tokens.append(next_word_idx)
        if stop_set and model.vocab.decode(next_word_idx) in stop_set:
            break
    return [model.vocab.decode(word_idx) for word_idx in tokens]
generate_functions.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录