TweetGenerator.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:MyTwitterBot 作者: felipessalvatore 项目源码 文件源码
def __generate_tweet_no_unk__(self,
                                  session,
                                  model,
                                  config,
                                  starting_text='<eos>',
                                  stop_tokens=None,
                                  temp=1.0,
                                  CharSize=140):
        """
        Private method to generate a tweet containing no unk tokens.

        Words are sampled one at a time from ``model`` at temperature
        ``temp``. Whenever the sampled word is the unknown token
        (``self.dataholder.unk_token``) or appears in ``self.black_list``,
        it is replaced by a word drawn uniformly at random from
        ``self.dataholder.all_noums`` (the list of nouns from the vocab).
        The index of the word actually emitted, including a substituted
        noun, is fed back to the model. This keeps the RNN state
        consistent with the generated text.

        Generation stops when one of these happens:
          * the space-joined tweet is exactly ``CharSize`` characters long;
          * ``TweetValid`` rejects the tweet after adding a word (that last
            word is dropped);
          * the emitted word is one of ``stop_tokens``.

        :param session: session used to run ``model``.
        :param model: language model exposing ``initial_state``,
            ``final_state``, ``predictions``, ``input_placeholder`` and
            ``dropout_placeholder``.
        :param config: not used by this method.
        :param starting_text: whitespace-separated seed words. They are
            included in the returned tweet. Must contain at least one word,
            because the last seed word is the first model input.
        :param stop_tokens: words that end generation once emitted.
        :param temp: sampling temperature passed to ``sample``.
        :param CharSize: maximum tweet length in characters.
        :type session: tf Session
        :type model: RNNLanguageModel
        :type config: Config
        :type starting_text: str
        :type stop_tokens: None or list of str
        :type temp: float
        :type CharSize: int
        :rtype: list of str
        """
        vocab = self.dataholder.vocab
        unk_token = self.dataholder.unk_token
        nouns = self.dataholder.all_noums
        state = session.run(model.initial_state)
        tweet = starting_text.split()
        tokens = [vocab.encode(word) for word in tweet]
        while True:
            feed = {model.input_placeholder: [[tokens[-1]]],
                    model.initial_state: state,
                    model.dropout_placeholder: 1.0}
            state, y_pred = session.run([model.final_state,
                                         model.predictions[-1]],
                                        feed_dict=feed)
            next_word_idx = sample(y_pred[0], temperature=temp)
            next_word = vocab.decode(next_word_idx)
            if next_word == unk_token or next_word in self.black_list:
                # Emit a random noun instead of an unk/blacklisted word, and
                # feed that noun (not the rejected word) to the next step.
                choice = np.random.choice(len(nouns), 1)[0]
                next_word = nouns[choice]
                next_word_idx = vocab.encode(next_word)
            tokens.append(next_word_idx)
            tweet.append(next_word)
            tweet_as_str = " ".join(tweet)
            if len(tweet_as_str) == CharSize:
                break
            if not TweetValid(tweet_as_str, CharNumber=CharSize):
                # The word just added made the tweet invalid: drop it, stop.
                tweet.pop()
                break
            if stop_tokens and next_word in stop_tokens:
                break
        return tweet
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号