feedback.py source code

python

Project: sequencing   Author: SwordYork
def sample(self, logits, log_probs, prev_finished, time):
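        """
        Perform one beam-search step: select the top ``beam_size``
        continuations for each batch entry based on ``logits``.

        :param logits: [_batch_size * beam_size, vocab_size], unnormalized
            scores for the next token.
        :param log_probs: [_batch_size * beam_size], accumulated log
            probabilities of the sequences decoded so far.
        :param prev_finished: [_batch_size * beam_size], whether each beam
            has already finished (emitted EOS).
        :param time: current decoding step (scalar); at step 0 only the
            first beam of each batch entry is expanded.
        :return: (sample_ids, beam_ids, next_log_probs), each of shape
            [_batch_size * beam_size]: the selected token ids, the indices of
            their parent beams in the flattened batch, and the updated
            accumulated log probabilities.
        """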

        # Per-token log probabilities (despite the name `probs`):
        # [_batch_size * beam_size, vocab_size]
        probs = tf.nn.log_softmax(logits)

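        # Finished beams may only be extended with EOS, at zero extra cost.
        # For rows where prev_finished is True, the factor (mask + 1) is 0 for
        # EOS and float32.max + 1 for every other token, so multiplying the
        # (negative) log probabilities pushes all non-EOS tokens to a very
        # large negative score (possibly -inf). Unfinished rows get factor 1.
        mask_tensor = [tf.float32.max] * self.vocab_size
        mask_tensor[self.eos_id] = -1.
        mask_tensor = tf.expand_dims(tf.constant(mask_tensor,
                                                 dtype=tf.float32), 0)
        mask_probs = (tf.expand_dims(tf.to_float(prev_finished), 1)
                      * mask_tensor + 1.) * probs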

        # Accumulated scores: [_batch_size * beam_size, vocab_size]
        log_probs = mask_probs + tf.expand_dims(log_probs, 1)
        # Put all beams of a batch entry into one row:
        # [_batch_size, beam_size * vocab_size]
        log_probs = tf.reshape(log_probs, [self._batch_size, -1])

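        # At step 0 all beams of a batch entry are expected to be identical
        # (they share the start state), so only the first beam's vocab_size
        # columns are searched; otherwise top_k would return duplicates.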
        log_probs_flat = tf.cond(
            tf.convert_to_tensor(time) > 0, lambda: log_probs,
            lambda: tf.slice(log_probs, [0, 0], [-1, self.vocab_size]))

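        # Best beam_size candidates per batch entry; word_ids index into the
        # flattened beam_size * vocab_size row.
        next_log_probs, word_ids = tf.nn.top_k(log_probs_flat, k=self.beam_size)

        next_log_probs = tf.reshape(next_log_probs, [-1])
        word_ids = tf.reshape(word_ids, [-1])

        # Token id within the vocabulary.
        sample_ids = tf.mod(word_ids, self.vocab_size)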

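        # word_ids // vocab_size is the parent beam index within one batch
        # entry; adding b * beam_size for batch entry b turns it into an
        # index into the flattened [_batch_size * beam_size] dimension.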
        beam_add = tf.tile([tf.range(self._batch_size)],
                           [self.beam_size, 1]) * self.beam_size

        beam_ids = tf.div(word_ids, self.vocab_size) \
                   + tf.reshape(tf.transpose(beam_add), [-1])

        return sample_ids, beam_ids, next_log_probs
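To make the index arithmetic easier to follow, here is a small NumPy re-implementation of the same step. It is an illustrative sketch, not code from the sequencing project: the names `beam_step_np` and `log_softmax` are hypothetical, `np.argsort` stands in for `tf.nn.top_k`, and finished beams are masked with an explicit `-inf` row instead of the `float32.max` multiplication used above.

```python
import numpy as np


def log_softmax(x):
    """Row-wise log-softmax, shifted by the row max for numerical stability."""
    x = x - x.max(axis=1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))


def beam_step_np(logits, log_probs, prev_finished, time,
                 batch_size, beam_size, eos_id):
    """Hypothetical NumPy version of one beam-search step.

    logits: [batch_size * beam_size, vocab_size]
    log_probs, prev_finished: [batch_size * beam_size]
    Returns (sample_ids, beam_ids, next_log_probs), each [batch_size * beam_size].
    """
    vocab_size = logits.shape[1]
    step = log_softmax(logits)

    # Finished beams may only emit EOS, at zero extra cost (explicit -inf mask
    # here, in place of the float32.max trick).
    finished_row = np.full(vocab_size, -np.inf)
    finished_row[eos_id] = 0.0
    step = np.where(prev_finished[:, None], finished_row[None, :], step)

    # Accumulate, then put all beams of a batch entry into one row.
    scores = (step + log_probs[:, None]).reshape(batch_size, -1)
    if time == 0:
        # Beams start identical, so only the first beam is searched.
        scores = scores[:, :vocab_size]

    # Top beam_size candidates per row, best first (stands in for top_k).
    word_ids = np.argsort(-scores, axis=1, kind="stable")[:, :beam_size]
    next_log_probs = np.take_along_axis(scores, word_ids, axis=1).reshape(-1)
    word_ids = word_ids.reshape(-1)

    sample_ids = word_ids % vocab_size
    # Offset each parent beam index by b * beam_size for batch entry b.
    offsets = np.repeat(np.arange(batch_size), beam_size) * beam_size
    beam_ids = word_ids // vocab_size + offsets
    return sample_ids, beam_ids, next_log_probs
```

For example, with `batch_size=2`, `beam_size=2`, a three-token vocabulary and accumulated scores `[0, -10, -10, 0]`, the second batch entry's winners come from its second beam, so their `beam_ids` are `1 + 1 * 2 = 3` in the flattened batch.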