trainer.py source code

python

Project: StackGAN  Author: hanzhanggit
def sample_encoded_context(self, embeddings):
    '''Helper function for init_opt.

    Returns the conditioning vector c and the weighted KL loss.
    '''
    # Build conditioning augmentation structure for text embedding
    # under different variable_scope: 'g_net' and 'hr_g_net'
    c_mean_logsigma = self.model.generate_condition(embeddings)
    mean = c_mean_logsigma[0]
    if cfg.TRAIN.COND_AUGMENTATION:
        # Reparameterization: c = mean + stddev * epsilon
        # epsilon = tf.random_normal(tf.shape(mean))
        epsilon = tf.truncated_normal(tf.shape(mean))
        # c_mean_logsigma[1] holds log(sigma), so exponentiate it
        stddev = tf.exp(c_mean_logsigma[1])
        c = mean + stddev * epsilon

        kl_loss = KL_loss(c_mean_logsigma[0], c_mean_logsigma[1])
    else:
        # No augmentation: use the mean directly, with no KL penalty
        c = mean
        kl_loss = 0
    # TODO: play with the coefficient for KL
    return c, cfg.TRAIN.COEFF.KL * kl_loss
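
The sampling step in the method above can be tried outside TensorFlow. The block below is a minimal NumPy sketch, not the project's code. It assumes that KL_loss computes the closed-form KL divergence between N(mean, sigma^2) and a standard normal, averaged over all elements. That is the usual choice for conditioning augmentation, but the snippet does not show it. The names kl_to_standard_normal and sample_condition are hypothetical, and the sketch draws epsilon from an ordinary normal rather than the truncated normal used above.

```python
import numpy as np


def kl_to_standard_normal(mu, log_sigma):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged over all elements.

    Assumption: this mirrors what KL_loss does; the snippet does not show it.
    """
    per_element = -log_sigma + 0.5 * (np.exp(2.0 * log_sigma) + mu ** 2 - 1.0)
    return per_element.mean()


def sample_condition(mu, log_sigma, rng):
    """Reparameterization: c = mu + sigma * eps, with eps ~ N(0, 1)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(log_sigma) * eps


rng = np.random.default_rng(0)
mu = np.zeros((4, 8))         # stand-in for c_mean_logsigma[0]
log_sigma = np.zeros((4, 8))  # stand-in for c_mean_logsigma[1]

c = sample_condition(mu, log_sigma, rng)
kl = kl_to_standard_normal(mu, log_sigma)
```

With mean 0 and log(sigma) 0 the distribution is already a standard normal, so the KL term vanishes. Moving the mean or the scale away from those values makes it positive, which is what the cfg.TRAIN.COEFF.KL weight penalizes.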