wgan_model.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Mendelssohn 作者: diggerdu 项目源码 文件源码
def build(self):
    """Build the WGAN training graph.

    Creates the generator output, the content (reconstruction) loss, the
    critic's Wasserstein losses, per-loss RMSProp optimizers, weight-clipping
    ops for the critic, and TensorBoard summaries.

    Reads:  self.input, self.target, self.c_lr, self.g_lr,
            self.clamp_lower, self.clamp_upper
    Sets:   self.output, self.content_loss, self.eva_op, self.concat_output,
            self.concat_target, self.fake_em, self.true_em, self.c_loss,
            self.g_loss, self.summary, self.c_opt, self.g_opt, self.content_opt
    """
    self.output = self._generator(self.input, name='gene')
    # Check shape agreement BEFORE the output is consumed by any loss.
    assert ten_sh(self.output) == ten_sh(self.target)
    # Content loss: log1p-weighted L1 distance between target and generation.
    self.content_loss = tf.reduce_mean(
        tf.multiply(tf.log1p(self.output),
                    tf.abs(tf.subtract(self.target, self.output))))
    # BUGFIX: TF >= 1.0 signature is tf.concat(values, axis); the old
    # (axis, values) order used here was removed in TF 1.0 (this file already
    # uses the TF 1.x tf.summary API, so the old order would raise).
    # NOTE(review): the 12.0 / 8.0 scale factors presumably undo a log
    # compression applied in the data pipeline — confirm against the loader.
    self.eva_op = tf.concat(
        [tf.exp(self.input * 12.0) - 1, tf.exp(self.output * 8.0) - 1],
        axis=1, name='eva_op')
    self.concat_output = tf.exp(tf.concat([self.input, self.output], axis=1))
    self.concat_target = tf.exp(tf.concat([self.input, self.target], axis=1))
    # Critic scores; the second call reuses the same 'critic' variables.
    self.fake_em = self._critic(self.concat_output, name='critic')
    self.true_em = self._critic(self.concat_target, name='critic', reuse=True)
    # WGAN losses: critic maximizes (true - fake) EM estimate, generator
    # maximizes the critic's score on fakes.
    self.c_loss = tf.reduce_mean(self.fake_em - self.true_em, name='c_loss')
    self.g_loss = tf.reduce_mean(-self.fake_em, name='g_loss')

    #### summaries ####
    content_loss_sum = tf.summary.scalar('content_loss', self.content_loss)  # fixed typo 'conntent'
    c_loss_sum = tf.summary.scalar('c_loss', self.c_loss)
    g_loss_sum = tf.summary.scalar('g_loss', self.g_loss)
    # Distinct names so the second image summary no longer shadows the first.
    gene_img_sum = tf.summary.image('gene_img', self.concat_output, max_outputs=1)
    tar_img_sum = tf.summary.image('tar_img', self.concat_target, max_outputs=1)
    self.summary = tf.summary.merge_all()
    ##################

    theta_g = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='gene')
    theta_c = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='critic')
    counter_g = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)
    counter_c = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)
    # RMSProp throughout, as prescribed by the original WGAN training scheme.
    self.c_opt = ly.optimize_loss(loss=self.c_loss, learning_rate=self.c_lr,
                                  optimizer=tf.train.RMSPropOptimizer,
                                  variables=theta_c,
                                  global_step=counter_c)
    self.g_opt = ly.optimize_loss(loss=self.g_loss, learning_rate=self.g_lr,
                                  optimizer=tf.train.RMSPropOptimizer,
                                  variables=theta_g,
                                  global_step=counter_g)
    # NOTE(review): content_opt shares counter_g with g_opt, so the step
    # counter advances on either update — confirm this is intended.
    self.content_opt = ly.optimize_loss(loss=self.content_loss,
                                        learning_rate=self.g_lr,
                                        optimizer=tf.train.RMSPropOptimizer,
                                        variables=theta_g,
                                        global_step=counter_g)
    # WGAN weight clipping: after each critic step, clamp every critic
    # variable into [clamp_lower, clamp_upper].
    clipped_c_var = [
        tf.assign(var, tf.clip_by_value(var, self.clamp_lower, self.clamp_upper))
        for var in theta_c]
    with tf.control_dependencies([self.c_opt]):
        self.c_opt = tf.tuple(clipped_c_var)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号