def train_ez(self):
    """Train the z-encoder (E_z) against the pretrained model.

    Builds an Adam optimizer over ``self.loss_z`` restricted to
    ``self.enz_vars``, restores the base model from ``self.model_path``,
    then runs ``self.max_epoch`` epochs of mini-batch optimization,
    periodically logging the loss and checkpointing the encoder to
    ``self.encode_z_model`` via ``self.saver_z``.

    Side effects: creates a TF session, writes the graph to
    ``self.log_dir``, and saves checkpoints. Returns nothing.
    """
    # Optimizer must be created before global_variables_initializer so its
    # Adam slot variables are covered by init; the subsequent restore only
    # overwrites the variables present in the checkpoint.
    opti_EZ = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.5).minimize(
        self.loss_z, var_list=self.enz_vars)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Only the graph is written; no scalar summaries are emitted here.
        summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
        self.saver.restore(sess, self.model_path)

        batch_num = 0
        e = 0
        step = 0
        # NOTE(review): a random batch offset was computed here and then
        # unconditionally overwritten with 0 (dead code) — kept fixed at 0
        # to preserve the original behavior.
        rand = 0
        while e <= self.max_epoch:
            # NOTE(review): `/` is true division on Python 3, so a partial
            # final batch is attempted; under Python 2 ints it floors.
            # Confirm which is intended before changing to `//`.
            while batch_num < len(self.ds_train) / self.batch_size:
                step = step + 1
                _, label_y = MnistData.getNextBatch(
                    self.ds_train, self.label_y, rand, batch_num,
                    self.batch_size)
                batch_z = np.random.normal(
                    0, 1, size=[self.batch_size, self.sample_size])
                # Optimization step for the encoder E_z only.
                sess.run(opti_EZ, feed_dict={self.y: label_y, self.z: batch_z})
                batch_num += 1

                if step % 10 == 0:
                    ez_loss = sess.run(
                        self.loss_z,
                        feed_dict={self.y: label_y, self.z: batch_z})
                    print("EPOCH %d step %d EZ loss %.7f" % (e, step, ez_loss))
                if np.mod(step, 50) == 0:
                    # Periodic checkpoint of the encoder variables.
                    self.saver_z.save(sess, self.encode_z_model)
            e += 1
            batch_num = 0

        save_path = self.saver_z.save(sess, self.encode_z_model)
        # Fixed: was a Python-2 `print` statement; this single-argument
        # call form is valid on both Python 2 and 3, matching the
        # print(...) call used above.
        print("Model saved in file: %s" % save_path)
        # Flush/close the writer so the graph event file is complete.
        summary_writer.close()
# (stray blog-page navigation text removed from code: "评论列表" = comment
# list, "文章目录" = article table of contents — web-scrape residue, not code)