def sample_encoded_context(self, embeddings):
    """Sample the conditioning vector ``c`` for the generator (helper for ``init_opt``).

    ``self.model.generate_condition(embeddings)`` returns a pair indexed as
    ``[0]`` (the mean) and ``[1]`` (exponentiated below to give the standard
    deviation, i.e. presumably ``log(sigma)`` -- confirm in the model).

    When ``cfg.TRAIN.COND_AUGMENTATION`` is enabled, ``c`` is sampled with the
    reparameterization ``c = mean + exp(log_sigma) * epsilon`` and a KL penalty
    is computed via ``KL_loss(mean, log_sigma)``. Otherwise ``c`` is simply the
    mean and the penalty is 0.

    Args:
        embeddings: Passed unchanged to ``self.model.generate_condition``
            (presumably a batch of text embeddings -- verify against caller).

    Returns:
        Tuple ``(c, weighted_kl)``, where ``weighted_kl`` is
        ``cfg.TRAIN.COEFF.KL * kl_loss``. It is ``COEFF.KL * 0`` when
        augmentation is disabled.
    """
    c_mean_logsigma = self.model.generate_condition(embeddings)
    mean = c_mean_logsigma[0]
    if cfg.TRAIN.COND_AUGMENTATION:
        # Only read the log-sigma component when it is actually needed.
        log_sigma = c_mean_logsigma[1]
        # NOTE(review): tf.truncated_normal re-draws values beyond 2 stddev, so
        # epsilon has variance ~0.77 rather than 1. tf.random_normal would give
        # the textbook unit-variance reparameterization (and match the KL term
        # exactly); truncated_normal is kept to preserve existing training
        # behavior.
        epsilon = tf.truncated_normal(tf.shape(mean))
        stddev = tf.exp(log_sigma)
        c = mean + stddev * epsilon
        # Presumably KL(N(mean, stddev^2) || N(0, I)) -- confirm against KL_loss.
        kl_loss = KL_loss(mean, log_sigma)
    else:
        c = mean
        kl_loss = 0
    return c, cfg.TRAIN.COEFF.KL * kl_loss
评论列表
文章目录