def sample_from_discretized_mix_logistic(l, nr_mix):
    """Draw one sample per spatial position from a discretized logistic mixture.

    A mixture component is chosen per position with the Gumbel-max trick.
    A continuous logistic sample is then drawn for each of the three output
    channels and clipped to ``[-1, 1]``. Channel 1 is shifted by
    ``coeffs[..., 0] * x0`` and channel 2 by
    ``coeffs[..., 1] * x0 + coeffs[..., 2] * x1``, so later channels are
    conditioned on the already-clipped earlier ones. Samples are not rounded
    to 8-bit values.

    Args:
        l: Rank-4 tensor of mixture parameters. Its last dimension must be
            ``10 * nr_mix``. The first ``nr_mix`` entries are mixture logits.
            The remaining ``9 * nr_mix`` are reshaped to
            ``[..., 3, 3 * nr_mix]`` and split into means, log-scales and raw
            (pre-tanh) channel-coupling coefficients. The static shape is
            presumably required to be fully defined: its list form (via
            ``int_shape``) feeds ``tf.reshape``, and ``get_shape()`` feeds
            ``tf.random_uniform``.
        nr_mix: Number of mixture components.

    Returns:
        Tensor of shape ``l.shape[:-1] + [3]`` with values in ``[-1, 1]``,
        in the same dtype as ``l``.
    """
    ls = int_shape(l)
    xs = ls[:-1] + [3]  # output shape: leading dims of l, then 3 channels
    # Noise and one-hot selection must share l's dtype. Hard-coding float32
    # made float16/float64 parameters fail with a dtype mismatch below.
    dtype = l.dtype
    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix * 3])
    # Sample the mixture indicator with the Gumbel-max trick:
    # argmax(logits - log(-log(u))), with u bounded away from 0 and 1 so
    # both logs stay finite.
    gumbel_u = tf.random_uniform(
        logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5, dtype=dtype)
    sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(gumbel_u)), 3),
                     depth=nr_mix, dtype=dtype)
    # Add a singleton channel axis so the selection broadcasts over all
    # 3 channels.
    sel = tf.reshape(sel, xs[:-1] + [1, nr_mix])
    # select logistic parameters of the chosen component
    means = tf.reduce_sum(l[:, :, :, :, :nr_mix] * sel, 4)
    # Floor log-scales at -7 (minimum scale exp(-7)). This presumably
    # mirrors the clamp used in the matching loss; confirm.
    log_scales = tf.maximum(
        tf.reduce_sum(l[:, :, :, :, nr_mix:2 * nr_mix] * sel, 4), -7.)
    # tanh bounds the channel-coupling coefficients to (-1, 1).
    coeffs = tf.reduce_sum(
        tf.nn.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, 4)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = tf.random_uniform(
        means.get_shape(), minval=1e-5, maxval=1. - 1e-5, dtype=dtype)
    # Inverse-CDF sample of a logistic: mean + scale * logit(u).
    x = means + tf.exp(log_scales) * (tf.log(u) - tf.log(1. - u))
    x0 = tf.minimum(tf.maximum(x[:, :, :, 0], -1.), 1.)
    x1 = tf.minimum(tf.maximum(
        x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, -1.), 1.)
    x2 = tf.minimum(tf.maximum(
        x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1,
        -1.), 1.)
    # Stacking on axis 3 gives shape xs. This is equivalent to reshaping
    # each channel to [..., 1] and concatenating.
    return tf.stack([x0, x1, x2], axis=3)
评论列表
文章目录