Autoencoder.py source code

Project: Keras-GAN  Author: Shaofanl

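import os

import numpy as np
from keras.callbacks import LambdaCallback
from keras.optimizers import Adam

# vis_grid, transform and inverse_transform are helper functions defined
# elsewhere in the Keras-GAN project; they are not part of Keras.
# fit() is a method of the project's autoencoder class; self.autoencoder
# is the Keras model it wraps.


def fit(self, data_stream,
        nvis=20,
        nbatch=128,
        niter=1000,
        opt=None,
        save_dir='./'):
    if opt is None:
        opt = Adam(lr=0.0001)
    os.makedirs(save_dir, exist_ok=True)

    # Train the autoencoder to reconstruct its input under an MSE loss.
    ae = self.autoencoder
    ae.compile(optimizer=opt, loss='mse')

    # Save one raw batch from the stream as a reference image.
    vis_grid(next(data_stream()), (1, nvis), os.path.join(save_dir, 'sample.png'))

    # Fixed visualization set: nvis inputs followed by their reconstructions.
    sampleX = transform(next(data_stream())[:nvis])
    vis_grid(inverse_transform(np.concatenate([sampleX, ae.predict(sampleX)], axis=0)),
             (2, nvis), os.path.join(save_dir, 'sample_generate.png'))

    def vis_grid_f(epoch, logs):
        # After every epoch, plot reconstructions of the same fixed samples...
        vis_grid(inverse_transform(np.concatenate([sampleX, ae.predict(sampleX)], axis=0)),
                 (2, nvis), os.path.join(save_dir, '{}.png'.format(epoch)))
        # ...and checkpoint the weights every 50 epochs.
        if epoch % 50 == 0:
            ae.save_weights(os.path.join(save_dir, '{}_ae_params.h5'.format(epoch)),
                            overwrite=True)

    def transform_wrapper():
        # fit_generator expects the generator to loop indefinitely, so restart
        # the stream whenever a pass ends. An autoencoder's target is its own
        # input, so each batch is transformed once and yielded as (x, x).
        while True:
            for data in data_stream():
                x = transform(data)
                yield x, x

    # Keras 1 API: samples_per_epoch counts samples and nb_epoch counts epochs.
    # Keras 2 renamed them to steps_per_epoch (counted in batches) and epochs.
    ae.fit_generator(transform_wrapper(),
                     samples_per_epoch=nbatch,
                     nb_epoch=niter,
                     verbose=1,
                     callbacks=[LambdaCallback(on_epoch_end=vis_grid_f)])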
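
Two contracts in fit() are easy to miss: the generator handed to fit_generator must yield (input, target) pairs indefinitely, and for an autoencoder the target is the input itself; and the checkpoint test `epoch % 50 == 0` runs on Keras's 0-based epoch counter. The stdlib-only sketch below isolates both. `autoencoder_pairs` and `checkpoint_epochs` are hypothetical names, not part of Keras or Keras-GAN, and plain integers stand in for image batches.

```python
from itertools import islice


def autoencoder_pairs(data_stream, transform=lambda batch: batch):
    """Yield (x, x) pairs indefinitely.

    data_stream is a zero-argument callable returning a fresh, possibly
    finite, iterable of batches (a hypothetical stand-in for the project's
    data_stream). It is called again each time the previous pass is
    exhausted, so this generator never terminates.
    """
    while True:
        for batch in data_stream():
            x = transform(batch)  # preprocess once...
            yield x, x            # ...and use the result as input and target


def checkpoint_epochs(niter, every=50):
    """0-based epoch indices with epoch % every == 0, i.e. the epochs after
    which fit() would call ae.save_weights."""
    return [epoch for epoch in range(niter) if epoch % every == 0]


def toy_stream():
    # Hypothetical finite stream of three "batches".
    return iter([1, 2, 3])


if __name__ == "__main__":
    print(list(islice(autoencoder_pairs(toy_stream), 5)))
    print(checkpoint_epochs(1000)[:3], len(checkpoint_epochs(1000)))
```

The first line printed is `[(1, 1), (2, 2), (3, 3), (1, 1), (2, 2)]`, showing the stream restarting after its third batch. With the default `niter=1000`, the second line reports 20 checkpoints, at epochs 0, 50, 100 and so on up to 950.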