def _train_internal(self, opts):
    """Train a GAN model.

    Runs up to ``opts['gan_epoch_num']`` epochs of
    ``num_points // batch_size`` steps each. Every step draws one minibatch
    of training points (indices sampled without replacement inside the
    batch, with ``self._data_weights`` as sampling probabilities) and one
    noise batch, then runs ``opts['d_steps']`` discriminator updates
    followed by ``opts['g_steps']`` generator updates on that same pair.

    Args:
        opts: Options dictionary. Keys read here: 'batch_size',
            'gan_epoch_num', 'd_steps', 'g_steps', 'verbose',
            'plot_every' (only when 'verbose' is truthy) and 'early_stop'.
            The dictionary is also forwarded unchanged to
            ``utils.generate_noise``, ``self._run_batch`` and
            ``Metrics.make_plots``.

    Returns:
        None. Training stops after the last epoch, or as soon as the
        global step counter exceeds a positive ``opts['early_stop']``.
    """
    batch_size = opts['batch_size']
    train_size = self._data.num_points
    # Floor division keeps the step count an integer even when true
    # division is in effect (Python 3 or `from __future__ import division`).
    batches_num = train_size // batch_size
    counter = 0  # Global step count across all epochs.
    logging.debug('Training GAN')
    for _epoch in range(opts["gan_epoch_num"]):
        for _idx in range(batches_num):
            data_ids = np.random.choice(
                train_size, batch_size,
                replace=False, p=self._data_weights)
            # Builtin `float` is what the removed `np.float` alias meant.
            batch_images = self._data.data[data_ids].astype(float)
            batch_noise = utils.generate_noise(opts, batch_size)

            # Update discriminator parameters. The feed is identical for
            # every inner step, so it is built once per minibatch.
            d_feed = {self._real_points_ph: batch_images,
                      self._noise_ph: batch_noise,
                      self._is_training_ph: True}
            for _ in range(opts['d_steps']):
                self._session.run(self._d_optim, feed_dict=d_feed)

            # Update generator parameters (reuses the same noise batch).
            g_feed = {self._noise_ph: batch_noise,
                      self._is_training_ph: True}
            for _ in range(opts['g_steps']):
                self._session.run(self._g_optim, feed_dict=g_feed)

            counter += 1

            if opts['verbose'] and counter % opts['plot_every'] == 0:
                logging.debug(
                    'Epoch: %d/%d, batch:%d/%d',
                    _epoch + 1, opts['gan_epoch_num'], _idx + 1, batches_num)
                metrics = Metrics()
                # Generate samples from the first 320 stored plotting noise
                # vectors, with the is-training placeholder fed as False.
                points_to_plot = self._run_batch(
                    opts, self._G, self._noise_ph,
                    self._noise_for_plots[0:320],
                    self._is_training_ph, False)
                metrics.make_plots(
                    opts,
                    counter,
                    None,
                    points_to_plot,
                    prefix='sample_e%04d_mb%05d_' % (_epoch, _idx))

            # Stop training altogether once the step budget is exceeded.
            # A bare `break` would only leave the batch loop, and each
            # remaining epoch would still execute one more training step.
            if opts['early_stop'] > 0 and counter > opts['early_stop']:
                return
# [web-page residue] Comment list
# [web-page residue] Article table of contents