def _get_json_type(obj):
    """Fallback encoder for ``json.dumps(..., default=...)``.

    Converts objects the ``json`` module cannot serialize natively:

    * objects exposing ``get_config()`` become
      ``{'class_name': <class name>, 'config': obj.get_config()}``;
    * numpy arrays become (nested) Python lists and numpy scalars become
      the corresponding builtin scalars;
    * callables and classes become their ``__name__``.

    Raises:
        TypeError: if ``obj`` matches none of the cases above.
    """
    if hasattr(obj, 'get_config'):
        return {'class_name': obj.__class__.__name__,
                'config': obj.get_config()}
    if type(obj).__module__ == np.__name__:
        # ndarray.item() only works on size-1 arrays, so arrays are
        # converted element-wise; numpy scalars map to builtin scalars.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj.item()
    if callable(obj) or type(obj).__name__ == type.__name__:
        return obj.__name__
    raise TypeError('Not JSON Serializable:', obj)


def _config_json(network):
    """Return ``network``'s class name and ``get_config()`` as UTF-8 JSON bytes."""
    return json.dumps({
        'class_name': network.__class__.__name__,
        'config': network.get_config(),
    }, default=_get_json_type).encode('utf8')


def save_model(model, filepath, overwrite=True):
    """Save a generator/discriminator model pair to a single HDF5 file.

    The file receives:

    * attribute ``keras_version``: the installed Keras version (UTF-8 bytes);
    * attributes ``generator_config`` / ``discriminator_config``: each
      network's class name and ``get_config()`` as UTF-8 encoded JSON;
    * groups ``generator_weights`` / ``discriminator_weights``: populated by
      each network's ``save_weights_to_hdf5_group``.

    Args:
        model: object exposing ``generator`` and ``discriminator`` attributes,
            each providing ``get_config()`` and
            ``save_weights_to_hdf5_group(group)``.
        filepath: destination path. It is opened in ``'w'`` mode, which
            truncates any existing file.
        overwrite: if False and ``filepath`` already exists, ask the user via
            ``keras.models.ask_to_proceed_with_overwrite`` and return without
            writing if they decline.
    """
    import h5py
    from keras import __version__ as keras_version

    if not overwrite and os.path.isfile(filepath):
        proceed = keras.models.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    # The context manager closes the file even if serialization or a
    # weight save raises part-way through.
    with h5py.File(filepath, 'w') as f:
        f.attrs['keras_version'] = str(keras_version).encode('utf8')
        # One shared helper, so each config records its own network's class.
        f.attrs['generator_config'] = _config_json(model.generator)
        f.attrs['discriminator_config'] = _config_json(model.discriminator)
        generator_weights_group = f.create_group('generator_weights')
        discriminator_weights_group = f.create_group('discriminator_weights')
        model.generator.save_weights_to_hdf5_group(generator_weights_group)
        model.discriminator.save_weights_to_hdf5_group(
            discriminator_weights_group)
        f.flush()
# Comment list (评论列表)
# Table of contents (文章目录)