def mvn_mix_log_probs(samples, q, ndims, num_components=3):
    '''Calculate the log probabilities of a multivariate-normal mixture model.

    Assumes q is [batchsize, numparams], packed along axis 1 as:
      * q[:, :num_components]
            unnormalized mixture logits (normalized with a softmax),
      * q[:, num_components:num_components*(1+ndims)]
            component means, reshaped to [batchsize, num_components, ndims],
      * q[:, num_components*(1+ndims):]
            packed Cholesky parameters, decoded by unpack_cholesky.

    Args:
        samples: rank-2 tensor sharing q's leading (batch) dimension and
            dtype (it is concatenated with slices of q along axis 1);
            presumably [batchsize, ndims] -- TODO confirm with callers.
        q: [batchsize, numparams] parameter tensor laid out as above.
        ndims: dimensionality of each mixture component.
        num_components: number of mixture components.

    Returns:
        log_sum_exp applied to the [batchsize, num_components] tensor of
        log(pi_c) + log N(sample | mu_c, chol_c); presumably one log
        probability per batch row (depends on log_sum_exp, defined elsewhere).
    '''
    # log_softmax is evaluated stably from the logits; log(softmax(.)) can
    # underflow to log(0) = -inf when one logit dominates the others.
    log_pi = tf.nn.log_softmax(q[:, :num_components])
    mu = tf.reshape(q[:, num_components:num_components * (1 + ndims)],
                    [-1, num_components, ndims])
    chol_q = q[:, num_components * (1 + ndims):]
    # Indexed below as chol[:, c, :, :]; presumably
    # [batchsize, num_components, ndims, ndims] -- see unpack_cholesky.
    chol = unpack_cholesky(chol_q, ndims, num_components)

    # Column boundaries inside one packed row: [mean | flat Cholesky | sample].
    mu_end = ndims
    chol_end = ndims * (1 + ndims)

    def _component_log_prob(x):
        # Unpack one row of packed_params and score its sample under the
        # corresponding MVN component.
        loc = x[:mu_end]
        scale_tril = tf.reshape(x[mu_end:chol_end], [ndims, ndims])
        return chol_mvn(loc, scale_tril).log_prob(x[chol_end:])

    log_probs = []
    for c in range(num_components):
        # map_fn iterates over a single tensor, so each row's mean,
        # flattened Cholesky factor and sample are packed side by side.
        packed_params = tf.concat(
            axis=1,
            values=[mu[:, c, :],
                    tf.reshape(chol[:, c, :, :], [-1, ndims * ndims]),
                    samples])
        log_probs.append(tf.map_fn(_component_log_prob, packed_params))

    # Results are concatenated component-major; reshape to
    # [num_components, batchsize] and transpose to [batchsize, num_components].
    log_probs = tf.transpose(
        tf.reshape(tf.concat(axis=0, values=log_probs), [num_components, -1]))
    # Weight each component's log density by its log mixture probability.
    log_probs = log_pi + log_probs
    return log_sum_exp(log_probs)
#######################################################################
################ PixelCNN++ utils #####################################
# Some of the code below is taken from the OpenAI PixelCNN++ implementation: https://github.com/openai/pixel-cnn