def train(self, config=None):
    """Run the DCGAN training loop, then plot the losses and save a GIF.

    Unless ``config.reset`` is set, ``self.restore`` is asked for a previous
    state first; when nothing was loaded the variables are initialised and
    the step counter starts at 0. Every batch performs one discriminator
    update followed by one generator update, both fed the same noise batch.
    After each epoch the last losses are printed, a batch is drawn from
    ``self.sampler`` and handed to ``visualize``, and ``self.save`` is
    called. When all epochs are done the loss curves go to ``plot`` and the
    collected frames go to ``save_gif``.

    NOTE(review): the source this was recovered from had lost its
    indentation. The nesting used here (per-batch updates in the inner loop;
    report/sample/checkpoint once per epoch; plot and GIF once at the end)
    is inferred from the "[Epoch {}]" message and the per-epoch
    ``visualize`` call -- confirm against the original file.

    Args:
        config: Settings object. The attributes read here are ``data_dir``,
            ``data``, ``batch_size``, ``reset``, ``checkpoint_dir`` and
            ``epoch``. Required in practice despite the ``None`` default.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    if config is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("train() requires a config object, got None")

    loader = Loader(config.data_dir, config.data, config.batch_size)

    # NOTE(review): restore() receives config.checkpoint_dir while save()
    # below receives "<checkpoint_dir>_<data>" -- confirm both resolve to the
    # same location, otherwise a saved run may never be picked up again.
    loaded = False
    if not config.reset:
        loaded, global_step = self.restore(config.checkpoint_dir)
    if not loaded:
        # Initialise in the session used for every other op in this method,
        # rather than depending on whichever default session is installed.
        self.sess.run(tf.global_variables_initializer())
        global_step = 0

    d_losses = []
    g_losses = []
    steps = []
    gif = []
    # Pre-bound so the per-epoch report cannot hit an unbound name when the
    # loader reports zero batches.
    d_loss = g_loss = None
    z_shape = [config.batch_size, self.z_dim]

    for epoch in range(config.epoch):
        loader.reset()
        for _ in range(loader.batch_num):
            batch_X = np.asarray(loader.next_batch(), dtype=np.float32)
            # Maps 0 -> -1 and 255 -> 1; assumes the loader yields 8-bit
            # pixel intensities -- TODO confirm.
            batch_X = (batch_X - 127.5) / 127.5
            batch_z = np.random.uniform(-1, 1, z_shape)

            # One discriminator step, then one generator step on the same z.
            _, d_loss = self.sess.run(
                [self.d_train_op, self.d_loss],
                feed_dict={self.X: batch_X, self.z: batch_z})
            _, g_loss = self.sess.run(
                [self.g_train_op, self.g_loss],
                feed_dict={self.z: batch_z})

            d_losses.append(d_loss)
            g_losses.append(g_loss)
            steps.append(global_step)
            global_step += 1

        # Reports the losses of the last batch of the epoch only.
        print(" [Epoch {}] d_loss:{}, g_loss:{}".format(epoch, d_loss, g_loss))

        # Fresh noise for the per-epoch sample that becomes one GIF frame.
        batch_z = np.random.uniform(-1, 1, z_shape)
        imgs = self.sess.run(self.sampler, feed_dict={self.z: batch_z})
        gif.append(visualize(imgs, epoch, config.data))

        self.save("{}_{}".format(config.checkpoint_dir, config.data),
                  global_step, model_name="dcgan")

    plot({'d_loss': d_losses, 'g_loss': g_losses}, steps,
         title="DCGAN loss ({})".format(config.data),
         x_label="Step", y_label="Loss")
    save_gif(gif, "gen_img_{}".format(config.data))
# Page-navigation text left over from the source web page (not code):
# "Comment list" / "Article contents"