gan.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:adagan 作者: tolstikhin 项目源码 文件源码
def _train_internal(self, opts):
        """Train a GAN model.

        Runs `opts['gan_epoch_num']` epochs. Every step draws one minibatch
        of `opts['batch_size']` indices without replacement, using
        `self._data_weights` as sampling probabilities, then runs
        `opts['d_steps']` discriminator updates followed by
        `opts['g_steps']` generator updates on that batch. The same noise
        batch is fed to both kinds of update.

        Args:
            opts: Options dictionary. Keys read here: 'batch_size',
                'gan_epoch_num', 'd_steps', 'g_steps', 'verbose',
                'plot_every' and 'early_stop'.

        If `opts['verbose']` is truthy, samples are generated and plotted
        every `opts['plot_every']` steps. If `opts['early_stop']` is
        positive, training stops entirely once the number of completed
        steps exceeds it.
        """

        batch_size = opts['batch_size']
        epochs_num = opts['gan_epoch_num']
        train_size = self._data.num_points
        # Floor division keeps the step count an integer on Python 3 as well
        # (plain `/` would yield a float there and break `range`).
        batches_num = train_size // batch_size

        counter = 0
        stop_training = False
        logging.debug('Training GAN')
        for _epoch in range(epochs_num):
            for _idx in range(batches_num):
                data_ids = np.random.choice(train_size, batch_size,
                                            replace=False, p=self._data_weights)
                # np.float64 is the dtype the removed `np.float` alias
                # (builtin float) resolved to, so the cast is unchanged.
                batch_images = self._data.data[data_ids].astype(np.float64)
                batch_noise = utils.generate_noise(opts, batch_size)

                # Update discriminator parameters
                d_feed = {self._real_points_ph: batch_images,
                          self._noise_ph: batch_noise,
                          self._is_training_ph: True}
                for _iter in range(opts['d_steps']):
                    self._session.run(self._d_optim, feed_dict=d_feed)

                # Update generator parameters
                g_feed = {self._noise_ph: batch_noise,
                          self._is_training_ph: True}
                for _iter in range(opts['g_steps']):
                    self._session.run(self._g_optim, feed_dict=g_feed)
                counter += 1

                if opts['verbose'] and counter % opts['plot_every'] == 0:
                    logging.debug(
                        'Epoch: %d/%d, batch:%d/%d',
                        _epoch + 1, epochs_num, _idx + 1, batches_num)
                    metrics = Metrics()
                    # Generate samples with the training flag switched off.
                    points_to_plot = self._run_batch(
                        opts, self._G, self._noise_ph,
                        self._noise_for_plots[0:320],
                        self._is_training_ph, False)
                    metrics.make_plots(
                        opts,
                        counter,
                        None,
                        points_to_plot,
                        prefix='sample_e%04d_mb%05d_' % (_epoch, _idx))

                if opts['early_stop'] > 0 and counter > opts['early_stop']:
                    stop_training = True
                    break
            # `break` only leaves the batch loop; without this check every
            # remaining epoch would still run one more step before breaking.
            if stop_training:
                break
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号