models.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:gandlf 作者: codekansas 项目源码 文件源码
def save_model(model, filepath, overwrite=True):
    """Save a model's generator and discriminator to a single HDF5 file.

    The file receives three attributes (``keras_version``,
    ``generator_config`` and ``discriminator_config``, all UTF-8 encoded)
    and two groups (``generator_weights`` and ``discriminator_weights``)
    filled in by each sub-model's ``save_weights_to_hdf5_group``.

    Args:
        model: Object exposing ``generator`` and ``discriminator``
            attributes; each must provide ``get_config()`` and
            ``save_weights_to_hdf5_group(group)``.
        filepath: Path of the HDF5 file to write.
        overwrite: If False and ``filepath`` already exists, ask for
            confirmation via ``keras.models.ask_to_proceed_with_overwrite``
            and return without writing when it is declined.

    Raises:
        TypeError: If a config contains a value that neither ``json`` nor
            ``get_json_type`` can serialize.
    """

    def get_json_type(obj):
        """Fallback serializer passed to ``json.dumps`` as ``default``."""
        # Objects that can describe themselves are stored as name + config.
        if hasattr(obj, 'get_config'):
            return {'class_name': obj.__class__.__name__,
                    'config': obj.get_config()}

        # Values whose type is defined in the numpy module: unwrap them
        # into the equivalent plain Python value.
        if type(obj).__module__ == np.__name__:
            return obj.item()

        # Functions and classes are stored by name.
        if callable(obj) or type(obj).__name__ == type.__name__:
            return obj.__name__

        raise TypeError('Not JSON Serializable:', obj)

    import h5py
    from keras import __version__ as keras_version

    if not overwrite and os.path.isfile(filepath):
        proceed = keras.models.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    generator = model.generator
    discriminator = model.discriminator

    # The context manager guarantees the file handle is closed even when
    # serialization or weight writing raises part-way through.
    with h5py.File(filepath, 'w') as f:
        f.attrs['keras_version'] = str(keras_version).encode('utf8')
        # The class name must come from the generator itself; it was
        # previously taken from the discriminator, which mislabels the
        # config whenever the two sub-models are of different classes.
        f.attrs['generator_config'] = json.dumps({
            'class_name': generator.__class__.__name__,
            'config': generator.get_config(),
        }, default=get_json_type).encode('utf8')
        f.attrs['discriminator_config'] = json.dumps({
            'class_name': discriminator.__class__.__name__,
            'config': discriminator.get_config(),
        }, default=get_json_type).encode('utf8')

        generator_weights_group = f.create_group('generator_weights')
        discriminator_weights_group = f.create_group('discriminator_weights')
        generator.save_weights_to_hdf5_group(generator_weights_group)
        discriminator.save_weights_to_hdf5_group(discriminator_weights_group)

        f.flush()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号