def fit(self, data_stream,
        nvis=20,
        nbatch=128,
        niter=1000,
        opt=None,
        save_dir='./'):
    """Train ``self.autoencoder`` to reconstruct batches from ``data_stream``.

    Before training, a preview of raw data (``sample.png``) and a grid of
    inputs next to their reconstructions (``sample_generate.png``) are
    written to ``save_dir``. After every epoch a new reconstruction grid
    ``<epoch>.png`` is written. Whenever the epoch index is a multiple of 50,
    the weights are saved to ``<epoch>_ae_params.h5``.

    Args:
        data_stream: zero-argument callable returning an iterable of raw
            batches (presumably numpy arrays -- they are sliced with
            ``[:nvis]``). It is called again each time the iterable is
            exhausted, so a finite one-pass stream is fine.
        nvis: number of samples shown in each visualization grid.
        nbatch: forwarded to ``fit_generator`` as ``samples_per_epoch``.
        niter: number of epochs, forwarded as ``nb_epoch``.
        opt: optimizer for ``compile``; defaults to ``Adam(lr=0.0001)``.
        save_dir: output directory for images and weights; created if it
            does not exist.

    Raises:
        ValueError: if ``data_stream()`` yields no batches during training.
    """
    if opt is None:
        opt = Adam(lr=0.0001)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    ae = self.autoencoder
    ae.compile(optimizer=opt, loss='mse')

    # One preview batch serves both preview images. The builtin next()
    # works on Python 2 and 3, unlike the Python-2-only ``.next()`` method.
    # Slicing to nvis keeps the image count consistent with the grid shape.
    preview = next(iter(data_stream()))[:nvis]
    vis_grid(preview, (1, nvis), os.path.join(save_dir, 'sample.png'))
    sample_x = transform(preview)

    def _save_reconstructions(path):
        # Write a 2 x nvis grid: the first nvis images are the (transformed)
        # inputs, the next nvis are the autoencoder's reconstructions.
        pairs = np.concatenate([sample_x, ae.predict(sample_x)], axis=0)
        vis_grid(inverse_transform(pairs), (2, nvis), path)

    _save_reconstructions(os.path.join(save_dir, 'sample_generate.png'))

    def vis_grid_f(epoch, logs):
        # Keras LambdaCallback hook: snapshot reconstructions every epoch,
        # checkpoint weights every 50th epoch index.
        _save_reconstructions(os.path.join(save_dir, '{}.png'.format(epoch)))
        if epoch % 50 == 0:
            ae.save_weights(
                os.path.join(save_dir, '{}_ae_params.h5'.format(epoch)),
                overwrite=True)

    def transform_wrapper():
        # fit_generator expects its generator to loop indefinitely; restart
        # data_stream() whenever it is exhausted so training is not cut
        # short after a single pass over the data.
        while True:
            empty = True
            for data in data_stream():
                empty = False
                x = transform(data)  # transform once: input and target match
                yield x, x
            if empty:
                # Without this guard an empty stream would spin forever.
                raise ValueError('data_stream() yielded no batches')

    ae.fit_generator(transform_wrapper(),
                     samples_per_epoch=nbatch,
                     nb_epoch=niter,
                     verbose=1,
                     callbacks=[LambdaCallback(on_epoch_end=vis_grid_f)])
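# 评论列表 (comment list -- residue from the source web page)
# 文章目录 (table of contents -- residue from the source web page)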