gan.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:adagan 作者: tolstikhin 项目源码 文件源码
def _train_internal(self, opts):
        """Train a GAN model.

        Runs ``opts['gan_epoch_num']`` epochs of
        ``num_points // opts['batch_size']`` mini-batch iterations each.
        Every iteration draws a batch of training points (without
        replacement, with probabilities ``self._data_weights``) and a batch
        of noise, then runs ``self._d_optim`` ``opts['d_steps']`` times and
        ``self._g_optim`` ``opts['g_steps']`` times in ``self._session``.

        When ``opts['verbose']`` is truthy, every ``opts['plot_every']``
        iterations a sample of training points and the generator outputs
        for ``self._noise_for_plots`` are handed to ``Metrics.make_plots``.

        Args:
            opts: Options dictionary. Keys read here: 'batch_size',
                'gan_epoch_num', 'd_steps', 'g_steps', 'verbose' and
                'plot_every'. It is also forwarded unchanged to
                ``utils.generate_noise``, ``self._run_batch`` and
                ``Metrics.make_plots``.
        """
        batch_size = opts['batch_size']
        train_size = self._data.num_points
        # Floor division keeps the batch count an integer even when true
        # division is in effect (Python 3 or `from __future__ import
        # division`); a float here would make the batch loop raise.
        batches_num = train_size // batch_size
        # Number of real and generated points passed to the plotting code.
        num_plot_points = 320

        counter = 0
        logging.debug('Training GAN')
        for epoch in xrange(opts["gan_epoch_num"]):
            for batch_idx in xrange(batches_num):
                data_ids = np.random.choice(train_size, batch_size,
                                            replace=False,
                                            p=self._data_weights)
                # `np.float` was only an alias of the builtin `float` and
                # has been removed from recent NumPy releases.
                batch_images = self._data.data[data_ids].astype(float)
                batch_noise = utils.generate_noise(opts, batch_size)

                # Update discriminator parameters
                d_feed = {self._real_points_ph: batch_images,
                          self._noise_ph: batch_noise}
                for _ in xrange(opts['d_steps']):
                    self._session.run(self._d_optim, feed_dict=d_feed)

                # Update generator parameters (same noise batch as above)
                g_feed = {self._noise_ph: batch_noise}
                for _ in xrange(opts['g_steps']):
                    self._session.run(self._g_optim, feed_dict=g_feed)

                counter += 1
                if opts['verbose'] and counter % opts['plot_every'] == 0:
                    metrics = Metrics()
                    points_to_plot = self._run_batch(
                        opts, self._G, self._noise_ph,
                        self._noise_for_plots[:num_plot_points])
                    # Fresh weighted sample of real points to plot next to
                    # the generated ones.
                    plot_ids = np.random.choice(train_size, num_plot_points,
                                                replace=False,
                                                p=self._data_weights)
                    metrics.make_plots(
                        opts, counter,
                        self._data.data[plot_ids],
                        points_to_plot,
                        prefix='sample_e%04d_mb%05d_' % (epoch, batch_idx))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号