def __generate_tweet_no_unk__(self,
                              session,
                              model,
                              config,
                              starting_text='<eos>',
                              stop_tokens=None,
                              temp=1.0,
                              CharSize=140):
    """
    Private method to generate a sentence, one word at a time.

    The sentence will have at most ``CharSize`` characters (140 by
    default, i.e. a tweet). Whenever the sampled word is the unk token
    or belongs to ``self.black_list``, it is replaced by a noun picked
    at random from ``self.dataholder.all_noums``, so the returned
    sentence never contains those words.

    Generation stops when the joined sentence is exactly ``CharSize``
    characters long, when adding the last word makes the sentence
    invalid according to ``TweetValid`` (that word is then dropped),
    or when the sampled word is one of ``stop_tokens``.

    :param session: session used to run the model
    :type session: tf Session
    :param model: language model providing the placeholders and outputs
    :type model: RNNLanguageModel
    :param config: not used by this method; kept in the signature
    :type config: Config
    :param starting_text: whitespace-separated words that seed the sentence
    :type starting_text: str
    :param stop_tokens: words that end the generation when sampled
    :type stop_tokens: None or list of str
    :param temp: sampling temperature passed to ``sample``
    :type temp: float
    :param CharSize: maximum number of characters of the sentence
    :type CharSize: int
    :raises ValueError: if ``starting_text`` contains no words
    :rtype : list of str
    """
    seed_words = starting_text.split()
    if not seed_words:
        # Without at least one seed token there is nothing to feed the
        # model on the first step (tokens[-1] would fail).
        raise ValueError("starting_text must contain at least one word")

    vocab = self.dataholder.vocab
    state = session.run(model.initial_state)
    tweet = list(seed_words)
    tokens = [vocab.encode(word) for word in seed_words]

    while True:
        # Feed only the last token; the recurrent state carries the history.
        # NOTE(review): 1.0 is presumably a keep probability (dropout
        # disabled at generation time) -- confirm against the model.
        feed = {model.input_placeholder: [[tokens[-1]]],
                model.initial_state: state,
                model.dropout_placeholder: 1.0}
        state, y_pred = session.run(
            [model.final_state, model.predictions[-1]],
            feed_dict=feed)
        next_word_idx = sample(y_pred[0], temperature=temp)

        # Decode once and reuse the result for every check below.
        sampled_word = vocab.decode(next_word_idx)

        if (sampled_word == self.dataholder.unk_token
                or sampled_word in self.black_list):
            # Swap unwanted words for a randomly chosen noun.
            nouns = self.dataholder.all_noums
            choice = np.random.choice(len(nouns), 1)[0]
            next_word = nouns[choice]
        else:
            next_word = sampled_word

        # The model keeps being fed the index it actually sampled, even
        # when the emitted word was replaced by a noun.
        tokens.append(next_word_idx)
        tweet.append(next_word)
        tweet_as_str = " ".join(tweet)

        if len(tweet_as_str) == CharSize:
            break
        if not TweetValid(tweet_as_str, CharNumber=CharSize):
            # The last word made the sentence invalid: roll it back.
            tweet.pop()
            break
        # The stop test looks at the sampled word, not at its replacement.
        if stop_tokens and sampled_word in stop_tokens:
            break
    return tweet
# Comment list
# Article table of contents