def build(self):
    """Build the generator/critic graph, losses, summaries and train ops.

    Reads from ``self``: ``input``, ``target``, ``_generator``, ``_critic``,
    ``c_lr``, ``g_lr``, ``clamp_lower``, ``clamp_upper``.

    Sets on ``self``:
      output: generator output for ``self.input`` (built with name 'gene').
      content_loss: mean of ``log1p(output) * |target - output|``.
      eva_op: ``exp(12 * input) - 1`` and ``exp(8 * output) - 1``
          concatenated along axis 1.
      concat_output / concat_target: ``exp`` of ``input`` concatenated along
          axis 1 with ``output`` / ``target``; these are fed to the critic.
      fake_em / true_em: critic outputs for the generated / target pairs
          (the second call passes ``reuse=True`` to share the 'critic' scope).
      c_loss: ``mean(fake_em - true_em)``.
      g_loss: ``mean(-fake_em)``.
      summary: ``tf.summary.merge_all()`` over every summary in the graph.
      c_opt: critic RMSProp update followed by clipping every critic
          variable to ``[clamp_lower, clamp_upper]``.
      g_opt / content_opt: generator RMSProp updates for ``g_loss`` /
          ``content_loss``; both advance the same step counter.
    """
    self.output = self._generator(self.input, name='gene')
    # Check shapes before output and target are combined: tf.subtract would
    # otherwise broadcast mismatched shapes without complaint.
    # NOTE(review): ten_sh is defined elsewhere; presumably returns the
    # static shape of a tensor -- confirm.
    assert ten_sh(self.output) == ten_sh(self.target)

    # L1 error between target and output, weighted element-wise by
    # log1p(output).
    self.content_loss = tf.reduce_mean(
        tf.multiply(tf.log1p(self.output),
                    tf.abs(tf.subtract(self.target, self.output))))

    # TF >= 1.0 signature is tf.concat(values, axis); the pre-1.0
    # tf.concat(concat_dim, values) order fails under the 1.x API that the
    # rest of this method uses (tf.summary, tf.multiply, tf.subtract).
    self.eva_op = tf.concat(
        [tf.exp(self.input * 12.0) - 1, tf.exp(self.output * 8.0) - 1],
        axis=1, name='eva_op')
    self.concat_output = tf.exp(tf.concat([self.input, self.output], axis=1))
    self.concat_target = tf.exp(tf.concat([self.input, self.target], axis=1))

    self.fake_em = self._critic(self.concat_output, name='critic')
    # Second critic call reuses the variables created by the first.
    self.true_em = self._critic(self.concat_target, name='critic', reuse=True)
    self.c_loss = tf.reduce_mean(self.fake_em - self.true_em, name='c_loss')
    self.g_loss = tf.reduce_mean(-self.fake_em, name='g_loss')

    # Summary ops add themselves to the graph's summary collection, so their
    # return values need not be kept; merge_all() collects these and any
    # other summaries already in the graph.
    tf.summary.scalar('content_loss', self.content_loss)
    tf.summary.scalar('c_loss', self.c_loss)
    tf.summary.scalar('g_loss', self.g_loss)
    tf.summary.image('gene_img', self.concat_output, max_outputs=1)
    tf.summary.image('tar_img', self.concat_target, max_outputs=1)
    self.summary = tf.summary.merge_all()

    theta_g = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='gene')
    theta_c = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='critic')
    counter_g = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)
    counter_c = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)

    c_train = ly.optimize_loss(loss=self.c_loss, learning_rate=self.c_lr,
                               optimizer=tf.train.RMSPropOptimizer,
                               variables=theta_c,
                               global_step=counter_c)
    # g_opt and content_opt update the same generator variables and both
    # increment counter_g.
    self.g_opt = ly.optimize_loss(loss=self.g_loss, learning_rate=self.g_lr,
                                  optimizer=tf.train.RMSPropOptimizer,
                                  variables=theta_g,
                                  global_step=counter_g)
    self.content_opt = ly.optimize_loss(loss=self.content_loss,
                                        learning_rate=self.g_lr,
                                        optimizer=tf.train.RMSPropOptimizer,
                                        variables=theta_g,
                                        global_step=counter_g)

    # Clip critic weights AFTER the critic update. The clip/assign ops must
    # be created inside the control_dependencies context: ops built before
    # it carry no dependency on c_train, so they could run before or
    # concurrently with the update and leave weights outside the range.
    with tf.control_dependencies([c_train]):
        self.c_opt = tf.tuple([
            tf.assign(var, tf.clip_by_value(var, self.clamp_lower,
                                            self.clamp_upper))
            for var in theta_c
        ])
评论列表
文章目录