def sample_inference_model(self,
                           source: tf.Tensor, length: tf.Tensor,
                           samples: int=1,
                           reuse: bool=False) -> tf.Tensor:
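    """Build the sampling-based inference part of the graph.

    Draws `samples` candidate translations for `source`, assigns a very
    low log-probability to candidates that never emitted <eos>, and
    returns the candidates sorted by log-probability, best first.
    """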
    x = tf.cast(source, tf.int32)
    logprops, labels = bytenet_sampling_translator(
        x,
        beam_size=samples,
        **self._parameters,
        name="bytenet-model",
        reuse=reuse
    )
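    # logprops.shape = (batch, beam), labels.shape = (batch, beam, time)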
    # check if <eos> exists in each sequence
    # eos_found.shape = (batch, beam)
    eos_found = tf.reduce_any(tf.equal(labels, 1), axis=2)
    # set the log-probability to something very small if <eos> was not found
    # log(epsilon) = -1e9
    log_eps = tf.constant(-1e9, dtype=logprops.dtype)
    logprops = tf.where(eos_found,
                        logprops,
                        tf.fill(tf.shape(logprops), log_eps))
    # sort the beams by log-probability, best first
    _, indices = tf.nn.top_k(logprops, k=samples, sorted=True)
    labels = batch_beam_gather(labels, indices)
    return tf.cast(labels, source.dtype)
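
# `batch_beam_gather` is a project helper that is not shown here. Below is a
# minimal sketch of what it is assumed to do, based on its use above: reorder
# the beam axis of a (batch, beam, time) tensor using per-batch (batch, beam)
# indices. This is an assumption, not the project's actual implementation.
# (Assumes `import tensorflow as tf` as in the listing above, TF >= 1.14.)
def batch_beam_gather(tensor: tf.Tensor, indices: tf.Tensor) -> tf.Tensor:
    # tf.gather with batch_dims=1 selects indices[b, :] within each batch b,
    # so beam k of the result is beam indices[b, k] of the input.
    return tf.gather(tensor, indices, axis=1, batch_dims=1)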