model.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目：GalaxyGAN_python 作者：Ireneruru 项目源码 文件源码
def __init__(self):
        """Build the conditional-GAN training graph.

        Creates the two input placeholders, the generator output, the
        discriminator logits for the real and the generated pair, both
        losses, and the split of trainable variables into discriminator
        and generator lists.

        Attributes set:
            image: float32 placeholder of shape
                (1, conf.train_size, conf.train_size, conf.img_channel);
                scored by the discriminator with label 1.
            cond: float32 placeholder of the same shape; passed to the
                generator and, alongside each image, to the discriminator.
            gen_img: result of ``self.generator(self.cond)``.
            delta: square of (mean of ``image`` - mean of ``gen_img``).
            d_loss: mean sigmoid cross-entropy of the real logits vs. 1
                plus that of the generated logits vs. 0.
            g_loss: mean sigmoid cross-entropy of the generated logits
                vs. 1, plus ``conf.L1_lambda`` * mean |image - gen_img|,
                plus ``conf.sum_lambda`` * ``delta``.
            d_vars / g_vars: trainable variables whose name contains
                'disc' / 'gen' respectively.
        """
        # Both placeholders share one static shape; the batch dimension is
        # fixed at 1.
        input_shape = (1, conf.train_size, conf.train_size, conf.img_channel)
        self.image = tf.placeholder(tf.float32, shape=input_shape)
        self.cond = tf.placeholder(tf.float32, shape=input_shape)

        self.gen_img = self.generator(self.cond)

        # NOTE(review): the third argument is presumably a variable-reuse
        # flag (False on the first call, True on the second) -- confirm
        # against discriminator(). Call order matters for graph building.
        real_logits = self.discriminator(self.image, self.cond, False)
        fake_logits = self.discriminator(self.gen_img, self.cond, True)

        def _mean_xent(logits, labels):
            # Scalar mean of element-wise sigmoid cross-entropy.
            return tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

        d_real_loss = _mean_xent(real_logits, tf.ones_like(real_logits))
        d_fake_loss = _mean_xent(fake_logits, tf.zeros_like(fake_logits))

        mean_gap = tf.reduce_mean(self.image) - tf.reduce_mean(self.gen_img)
        self.delta = tf.square(mean_gap)

        self.d_loss = d_real_loss + d_fake_loss

        # Generator objective: fool the discriminator, plus two
        # regularisers (pixel-wise L1 and the squared mean difference).
        adversarial_term = _mean_xent(fake_logits, tf.ones_like(fake_logits))
        l1_term = conf.L1_lambda * tf.reduce_mean(tf.abs(self.image - self.gen_img))
        mean_term = conf.sum_lambda * self.delta
        self.g_loss = adversarial_term + l1_term + mean_term

        trainable = tf.trainable_variables()

        def _select(tag):
            # NOTE(review): plain substring match -- any variable whose
            # full name contains `tag` anywhere is selected; confirm the
            # scope names used by generator()/discriminator() don't overlap.
            return [v for v in trainable if tag in v.name]

        self.d_vars = _select('disc')
        self.g_vars = _select('gen')
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号