def test_saving_lambda_custom_objects():
    """Check that a model with a ``Lambda`` layer survives a save/load round trip.

    A small model (``Input`` -> ``Lambda`` calling ``square_fn`` -> ``Dense``)
    is compiled, trained on a single random batch, saved to a temporary
    ``.h5`` file and reloaded with ``square_fn`` supplied through
    ``custom_objects``. The reloaded model must reproduce the original
    model's predictions on the same input within ``atol=1e-05``.
    """
    inputs = Input(shape=(3,))
    # NOTE(review): the lambda presumably refers to ``square_fn`` by name on
    # purpose, so that reloading depends on ``custom_objects`` -- keep it as a
    # lambda rather than passing ``square_fn`` directly.
    hidden = Lambda(lambda x: square_fn(x), output_shape=(3,))(inputs)
    outputs = Dense(3)(hidden)

    model = Model(inputs, outputs)
    model.compile(loss=objectives.MSE,
                  optimizer=optimizers.RMSprop(lr=0.0001),
                  metrics=[metrics.categorical_accuracy])

    x = np.random.random((1, 3))
    y = np.random.random((1, 3))
    model.train_on_batch(x, y)
    out = model.predict(x)

    # mkstemp returns an already-open OS-level descriptor; close it right away
    # so it is not leaked and so the file can be reopened/removed by path.
    fd, fname = tempfile.mkstemp('.h5')
    os.close(fd)
    try:
        save_model(model, fname)
        model = load_model(fname, custom_objects={'square_fn': square_fn})
    finally:
        # Remove the temp file even when saving or loading raises.
        os.remove(fname)

    out2 = model.predict(x)
    assert_allclose(out, out2, atol=1e-05)
评论列表
文章目录