wgan.py source code

python

Project: DeepLearning    Author: Wanwannodao
def __init__(self, batch_size):
        # Critic and generator sub-networks (defined elsewhere in this project)
        self.C = Critic(batch_size)
        self.G = Generator(batch_size)

        # Real input images and a scalar p fed to the critic
        # (presumably a dropout keep-probability)
        self.X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")
        self.p = tf.placeholder(tf.float32, name="p")

        # Generated (fake) images and the critic's score for them
        self.gen_img = self.G()
        g_logits = self.C(self.gen_img, self.p)

        # WGAN losses: the generator raises the critic score on fakes,
        # the critic widens the score gap between real and fake samples.
        self.g_loss = -tf.reduce_mean(g_logits)
        self.c_loss = tf.reduce_mean(-self.C(self.X, self.p, reuse=True) + g_logits)
        #self.g_loss = tf.reduce_mean(tf.reduce_sum(g_logits, axis=1))
        #self.c_loss = tf.reduce_mean(tf.reduce_sum(-self.C(self.X, self.p, reuse=True) + g_logits, axis=1))

        # RMSProp with a small learning rate, as recommended in the WGAN paper
        c_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        g_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        c_grads_and_vars = c_opt.compute_gradients(self.c_loss)
        g_grads_and_vars = g_opt.compute_gradients(self.g_loss)

        # Keep only the gradients that belong to each sub-network's own variables
        c_grads_and_vars = [(grad, var) for grad, var in c_grads_and_vars
                            if grad is not None and var.name.startswith("C")]
        g_grads_and_vars = [(grad, var) for grad, var in g_grads_and_vars
                            if grad is not None and var.name.startswith("G")]

        self.c_train_op = c_opt.apply_gradients(c_grads_and_vars)
        self.g_train_op = g_opt.apply_gradients(g_grads_and_vars)

        # Weight clipping keeps the critic approximately 1-Lipschitz
        self.w_clip = [var.assign(tf.clip_by_value(var, -0.01, 0.01))
                       for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="C")]
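
For context, a minimal training loop for this model might look like the sketch below. It assumes TensorFlow 1.x, that the constructor above belongs to a class named WGAN, that Generator samples its own noise (since self.G() takes no input), and that next_batch is a hypothetical helper yielding real MNIST batches; the 5-critic-updates-per-generator-update schedule and the 0.5 value for p are choices from the standard WGAN recipe, not taken from this file.

import tensorflow as tf

# Hypothetical usage sketch, not part of the original file.
model = WGAN(batch_size=64)  # assumed class name for the constructor shown above

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for step in range(10000):
        # Train the critic several times per generator step,
        # clipping its weights after every update.
        for _ in range(5):
            real = next_batch(64)  # hypothetical helper: (64, 28, 28, 1) real images
            sess.run(model.c_train_op, feed_dict={model.X: real, model.p: 0.5})
            sess.run(model.w_clip)

        # One generator update; only p is needed since the generator loss
        # does not depend on the real-image placeholder X.
        sess.run(model.g_train_op, feed_dict={model.p: 0.5})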