def __build_graph(self):
    """Assemble the TensorFlow graph and keep its handles on ``self``.

    The steps run in this order:

    1. ``self.model.inputs()`` provides seven values, stored as the
       ``plhold_*`` attributes.
    2. ``self.model.inference(...)`` provides seven values, stored as the
       ``op_*`` attributes.
    3. ``self.model.loss(...)`` and ``self.model.optimize(...)`` provide
       ``self.loss`` and ``self.train_step``.
    4. The numerics-check op, the variable initializer, the saver and the
       merged summary op are created last.

    Nothing is returned; every result is stored as an attribute.
    """
    # Input handles, exactly as handed back by the model.
    (
        self.plhold_img,
        self.plhold_greylevel,
        self.plhold_latent,
        self.plhold_is_training,
        self.plhold_keep_prob,
        self.plhold_kl_weight,
        self.plhold_lossweights,
    ) = self.model.inputs()

    # Inference graph.
    (
        self.op_mean,
        self.op_stddev,
        self.op_vae,
        self.op_mean_test,
        self.op_stddev_test,
        self.op_vae_test,
        self.op_vae_condinference,
    ) = self.model.inference(
        self.plhold_img,
        self.plhold_greylevel,
        self.plhold_latent,
        self.plhold_is_training,
        self.plhold_keep_prob,
    )

    # Loss function and gradient-descent step for the VAE.
    self.loss = self.model.loss(
        self.plhold_img,
        self.op_vae,
        self.op_mean,
        self.op_stddev,
        self.plhold_kl_weight,
        self.plhold_lossweights,
    )
    self.train_step = self.model.optimize(self.loss, epsilon=1e-6)

    # Standard steps. These come after the model ops above have been
    # added to the graph.
    self.check_nan_op = tf.add_check_numerics_ops()
    self.init = tf.global_variables_initializer()
    # NOTE(review): per the TF1 Saver docs, max_to_keep=0 means old
    # checkpoints are never deleted -- confirm this is intended.
    self.saver = tf.train.Saver(max_to_keep=0)
    self.summary_op = tf.summary.merge_all()
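# Comment list  (navigation residue from the source web page; not code)
# Article table of contents  (navigation residue from the source web page; not code)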