def test_build_predict_func(self, get_model):
    """Test that ``KTB.build_predict_func`` builds a usable predict function.

    A fresh session is started, the model returned by the ``get_model``
    fixture is compiled, and the function produced by
    ``KTB.build_predict_func`` is called on an all-ones input of shape
    ``(train_samples, input_dim)``. The first element of the result must
    have one entry per input sample.

    On the TensorFlow backend the session is cleared in a ``finally``
    clause, so that a failure anywhere in the test (model construction,
    compilation, prediction or the assertion) does not leave backend state
    behind for subsequent tests.
    """
    new_session()
    try:
        X_tr = np.ones((train_samples, input_dim))
        model = get_model()
        model.compile(loss='categorical_crossentropy',
                      optimizer='rmsprop',
                      metrics=['accuracy'])
        model_name = model.__class__.__name__
        pred_func = KTB.build_predict_func(model)
        tensors = [X_tr]
        # Models whose class is not exactly ``Model`` receive an extra
        # trailing input of 1.0. NOTE(review): this looks like a
        # learning-phase flag expected by KTB for non-functional models
        # (e.g. ``Sequential``) -- confirm against KTB.build_predict_func.
        if model_name != 'Model':
            tensors.append(1.)
        res = pred_func(tensors)
        assert len(res[0]) == len(X_tr)
    finally:
        # Runs even when the test fails, keeping backend state isolated
        # between tests.
        if K.backend() == 'tensorflow':
            K.clear_session()