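# dcgan.py 文件源码 (dcgan.py source file)
# python
# 阅读 29 收藏 0 点赞 0 评论 0 (views 29, favorites 0, likes 0, comments 0)
# 项目:DeepLearning 作者: Wanwannodao 项目源码 文件源码
# (project: DeepLearning, author: Wanwannodao; project source / file source)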
def train(self, config=None):
        """Train the DCGAN's discriminator and generator, checkpointing each epoch.

        Every epoch resets the loader and, for each batch, runs one
        discriminator update (fed real images and a latent batch) followed by
        one generator update (fed the same latent batch). At the end of each
        epoch the last batch's losses are printed, ``self.sampler`` is run on
        fresh noise and the result is passed to ``visualize`` (its return value
        is collected), and the model is saved. After all epochs the per-step
        losses are passed to ``plot`` and the collected frames to ``save_gif``.

        Args:
            config: Training configuration. Required despite the ``None``
                default. Attributes read: ``data_dir``, ``data``,
                ``batch_size``, ``reset``, ``checkpoint_dir``, ``epoch``.

        Raises:
            ValueError: If ``config`` is None.
        """
        if config is None:
            # The signature defaults to None, but every step below needs the
            # config; fail with a clear message instead of an AttributeError.
            raise ValueError("train() requires a config object")

        # Use one checkpoint directory for both restoring and saving.
        # NOTE(review): the original restored from config.checkpoint_dir but
        # saved to "<checkpoint_dir>_<data>", so resuming never found the
        # checkpoints written here. This assumes self.restore / self.save use
        # the directory they are given verbatim -- confirm.
        ckpt_dir = "{}_{}".format(config.checkpoint_dir, config.data)

        loader = Loader(config.data_dir, config.data, config.batch_size)

        loaded = False
        if not config.reset:
            loaded, global_step = self.restore(ckpt_dir)
        if not loaded:
            # Fresh start (or nothing to restore): initialize all variables.
            tf.global_variables_initializer().run()
            global_step = 0

        d_losses = []
        g_losses = []
        steps = []
        gif = []
        for epoch in range(config.epoch):
            loader.reset()
            # Bound up front so the epoch summary cannot raise NameError
            # when the loader yields zero batches.
            d_loss = g_loss = float("nan")
            for _ in range(loader.batch_num):
                batch_X = np.asarray(loader.next_batch(), dtype=np.float32)
                # Maps [0, 255] onto [-1, 1]; assumes the loader yields raw
                # pixel values in that range (TODO confirm against Loader).
                batch_X = (batch_X - 127.5) / 127.5
                # NOTE(review): z always has config.batch_size rows; if the
                # loader can return a short final batch, X and z row counts differ.
                batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim])

                _, d_loss = self.sess.run([self.d_train_op, self.d_loss],
                                          feed_dict={self.X: batch_X, self.z: batch_z})
                _, g_loss = self.sess.run([self.g_train_op, self.g_loss],
                                          feed_dict={self.z: batch_z})
                d_losses.append(d_loss)
                g_losses.append(g_loss)
                steps.append(global_step)
                global_step += 1

            # These are the losses of the epoch's last batch only, not an average.
            print(" [Epoch {}] d_loss:{}, g_loss:{}".format(epoch, d_loss, g_loss))
            batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim])
            imgs = self.sess.run(self.sampler, feed_dict={self.z: batch_z})
            gif.append(visualize(imgs, epoch, config.data))
            self.save(ckpt_dir, global_step, model_name="dcgan")

        plot({'d_loss': d_losses, 'g_loss': g_losses}, steps,
             title="DCGAN loss ({})".format(config.data), x_label="Step", y_label="Loss")
        save_gif(gif, "gen_img_{}".format(config.data))
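# --- Scraped web-page footer (not part of the program) ---
# 评论列表 (comment list)
# 文章目录 (table of contents)
# 问题 (questions)
# 面经 (interview experiences)
# 文章 (articles)
# 微信 公众号 (WeChat official account)
# 扫码关注公众号 (scan the QR code to follow the official account)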