def train(self):
    """Run the training loop for ``self.max_step`` steps.

    Each step performs one ``self.g_optim`` update (fetching
    ``self.summary_op`` in the same run and writing it to
    ``self.summary_writer``), followed by ``self.critic_iters`` runs of
    ``self.d_optim``. Every session run is fed a fresh batch pulled from
    ``inf_train_gen`` through the ``self.real_data`` placeholder.

    On every 100th step the generator loss, critic loss and slope are
    evaluated on one more batch and printed, and ``self.generate_samples``
    is called with a noise matrix that is drawn once up front, so the same
    noise is reused for every sampling call.
    """
    # Drawn once, before the loop: 10 * batch_size noise vectors.
    z_fixed = np.random.normal(size=[self.batch_size * 10, self.z_dim])
    batches = inf_train_gen(self.lines, self.batch_size, self.charmap)

    for step in trange(self.max_step):
        # Train generator.
        # The built-in next() is used instead of the Python 2-only
        # ``.next()`` method so the loop runs on both Python 2 and 3.
        batch = next(batches)
        summary_str, _ = self.sess.run(
            [self.summary_op, self.g_optim],
            feed_dict={self.real_data: batch},
        )
        self.summary_writer.add_summary(summary_str, global_step=step)
        self.summary_writer.flush()

        # Train critic.
        for _ in range(self.critic_iters):
            batch = next(batches)
            self.sess.run(self.d_optim, feed_dict={self.real_data: batch})

        # Periodic report on steps 100, 200, ... (1-based).
        # NOTE(review): the source lost its indentation; the sampling call
        # below is assumed to belong to this periodic branch -- confirm.
        if step % 100 == 99:
            batch = next(batches)
            g_loss, d_loss, slope = self.sess.run(
                [self.g_loss, self.d_loss, self.slope],
                feed_dict={self.real_data: batch},
            )
            print("[{}/{}] Loss_D: {:.6f} Loss_G: {:.6f} Slope: {:.6f}".
                  format(step + 1, self.max_step, d_loss, g_loss, slope))
            self.generate_samples(z_fixed, idx=step + 1)
评论列表
文章目录