def _train_internal(self, opts):
    """Train a GAN model by alternating discriminator and generator updates.

    For each of ``opts['gan_epoch_num']`` epochs, ``num_points // batch_size``
    minibatches are processed. Each minibatch is sampled independently from
    ``self._data.data`` (without replacement inside the batch) using
    ``self._data_weights`` as sampling probabilities. An "epoch" here is
    therefore a fixed number of weighted draws, not a strict pass over the
    data. Any trailing partial batch is dropped.

    For every minibatch, ``self._d_optim`` is run ``opts['d_steps']`` times
    on the real batch plus a noise batch from ``utils.generate_noise``. Then
    ``self._g_optim`` is run ``opts['g_steps']`` times on that same noise
    batch.

    Args:
        opts: Options dict. Keys read directly here: ``'batch_size'``,
            ``'gan_epoch_num'``, ``'d_steps'``, ``'g_steps'``, ``'verbose'``
            and ``'plot_every'``. The dict is also forwarded to
            ``utils.generate_noise``, ``self._run_batch`` and
            ``Metrics.make_plots``, which may read further keys.

    Returns:
        None. Updates are applied through ``self._session.run``.
    """
    batch_size = opts['batch_size']
    train_size = self._data.num_points
    # Floor division: plain `/` yields a float on Python 3, which range()
    # rejects. The trailing partial batch is dropped, as on Python 2.
    batches_num = train_size // batch_size
    if batches_num == 0:
        logging.warning(
            'Training set (%d points) is smaller than one batch (%d); '
            'no GAN updates will be performed.', train_size, batch_size)
    # Number of generated / real points drawn for each diagnostic plot.
    num_plot_points = 320
    counter = 0
    logging.debug('Training GAN')
    for epoch in range(opts['gan_epoch_num']):
        for idx in range(batches_num):
            data_ids = np.random.choice(train_size, batch_size,
                                        replace=False, p=self._data_weights)
            # np.float was removed in NumPy 1.24; it aliased the builtin
            # float, i.e. float64, so this conversion is unchanged.
            batch_images = self._data.data[data_ids].astype(np.float64)
            batch_noise = utils.generate_noise(opts, batch_size)
            # Update discriminator parameters on real + generated points.
            for _ in range(opts['d_steps']):
                self._session.run(
                    self._d_optim,
                    feed_dict={self._real_points_ph: batch_images,
                               self._noise_ph: batch_noise})
            # Update generator parameters, reusing the same noise batch.
            for _ in range(opts['g_steps']):
                self._session.run(
                    self._g_optim, feed_dict={self._noise_ph: batch_noise})
            counter += 1
            if opts['verbose'] and counter % opts['plot_every'] == 0:
                metrics = Metrics()
                points_to_plot = self._run_batch(
                    opts, self._G, self._noise_ph,
                    self._noise_for_plots[0:num_plot_points])
                # Sampling without replacement cannot draw more points than
                # the dataset holds, so clamp for small datasets.
                real_ids = np.random.choice(
                    train_size, min(num_plot_points, train_size),
                    replace=False, p=self._data_weights)
                metrics.make_plots(
                    opts, counter,
                    self._data.data[real_ids],
                    points_to_plot,
                    prefix='sample_e%04d_mb%05d_' % (epoch, idx))
评论列表
文章目录