def sample(self, logits, log_probs, prev_finished, time):
    """Select the next beam_size candidate tokens per batch entry.

    Performs one beam-search expansion step: scores every (beam, token)
    continuation, masks finished beams so they can only be extended with
    EOS at zero cost, and takes the top ``beam_size`` candidates per
    batch entry.

    :param logits: [_batch_size * beam_size, vocab.vocab_size] raw
        next-token scores for every live beam.
    :param log_probs: [_batch_size * beam_size,] accumulated log-prob of
        each currently decoded prefix.
    :param prev_finished: [_batch_size * beam_size,] per-beam flag, 1 if
        the beam already emitted EOS.
    :param time: scalar decoding step; at step 0 all beams are identical
        so only beam 0 is considered.
    :return: (sample_ids, beam_ids, next_log_probs) — chosen token ids,
        the global index of the beam each token extends, and the updated
        prefix scores, each flattened to [_batch_size * beam_size,].
    """
    # Per-token log-probabilities for every beam:
    # [_batch_size * beam_size, target_vocab_size]
    step_log_probs = tf.nn.log_softmax(logits)

    # Multiplicative mask for finished beams. For a finished beam the
    # factor becomes (max + 1) ~ float32.max at non-EOS positions, which
    # drives the (negative) log-prob to -inf, and exactly 0 at the EOS
    # position, so extending with EOS costs nothing. Unfinished beams
    # get a factor of 1 everywhere (mask is multiplied by 0).
    finished_mask = [tf.float32.max] * self.vocab_size
    finished_mask[self.eos_id] = -1.
    finished_mask = tf.expand_dims(
        tf.constant(finished_mask, dtype=tf.float32), 0)
    finished_col = tf.expand_dims(tf.to_float(prev_finished), 1)
    masked_step_probs = (finished_col * finished_mask + 1.) * step_log_probs

    # Add the running prefix scores, then group all candidates of one
    # batch entry into a single row:
    # [_batch_size, beam_size * target_vocab_size]
    total_scores = masked_step_probs + tf.expand_dims(log_probs, 1)
    total_scores = tf.reshape(tf.reshape(total_scores, [-1]),
                              [self._batch_size, -1])

    # At the first step every beam holds the same prefix; restricting
    # the search to the first vocab_size columns (beam 0) avoids
    # selecting beam_size duplicates of the same token.
    candidates = tf.cond(
        tf.convert_to_tensor(time) > 0,
        lambda: total_scores,
        lambda: tf.slice(total_scores, [0, 0], [-1, self.vocab_size]))

    next_log_probs, flat_ids = tf.nn.top_k(candidates, k=self.beam_size)
    next_log_probs = tf.reshape(next_log_probs, [-1])
    flat_ids = tf.reshape(flat_ids, [-1])

    # Decompose the flat candidate index into (token, in-batch beam).
    sample_ids = tf.mod(flat_ids, self.vocab_size)

    # Offset in-batch beam indices by batch_index * beam_size so they
    # address the flattened [_batch_size * beam_size] beam axis.
    batch_offsets = tf.tile([tf.range(self._batch_size)],
                            [self.beam_size, 1]) * self.beam_size
    beam_ids = (tf.div(flat_ids, self.vocab_size)
                + tf.reshape(tf.transpose(batch_offsets), [-1]))
    return sample_ids, beam_ids, next_log_probs