def test_rnn_layer(self):
    """Check that a converted SimpleRNN model reproduces the Keras predictions.

    For every entry of ``self.base_layer_params`` (repeated once per entry
    of ``self.rnn_layer_params``) a single-layer ``SimpleRNN`` Keras model
    is built, converted with ``get_mlkit_model_from_path`` and both models
    are run on the same generated input.  Shape mismatches and numerical
    mismatches are collected over all configurations and asserted together
    at the end, so one failing configuration does not hide the others.

    Raises:
        AssertionError: if any configuration yields flattened predictions
            of different shape, or a scaled absolute error above 0.01.
    """
    tolerance = 0.01
    models_tested = 0
    numerical_err_models = []
    shape_err_models = []

    for base_values in self.base_layer_params:
        base_params = dict(zip(self.params_dict.keys(), base_values))
        input_dims = base_params['input_dims']

        # NOTE(review): the values of rnn_layer_params are never passed to
        # SimpleRNN, so this loop only repeats each base configuration.  The
        # repetition is kept so the number of models exercised is unchanged.
        for _ in self.rnn_layer_params:
            model = Sequential()
            model.add(
                SimpleRNN(
                    base_params['output_dim'],
                    input_length=input_dims[1],
                    input_dim=input_dims[2],
                    activation=base_params['activation'],
                    return_sequences=base_params['return_sequences'],
                    go_backwards=base_params['go_backwards'],
                    unroll=base_params['unroll'],
                )
            )
            mlkitmodel = get_mlkit_model_from_path(model)
            input_data = generate_input(input_dims[0], input_dims[1], input_dims[2])
            keras_preds = model.predict(input_data).flatten()

            # Drop the TensorFlow graph/session once the Keras prediction is
            # taken -- presumably so state does not pile up across the many
            # models built by this test (TODO confirm).
            if K.tensorflow_backend._SESSION:
                import tensorflow as tf
                tf.reset_default_graph()
                K.tensorflow_backend._SESSION.close()
                K.tensorflow_backend._SESSION = None

            # Swap the first two axes before feeding the converted model --
            # presumably (batch, seq, dim) -> (seq, batch, dim); TODO confirm.
            input_data = np.transpose(input_data, [1, 0, 2])
            coreml_preds = mlkitmodel.predict({'data': input_data})['output'].flatten()

            # Counted exactly once per model, independent of the outcome.
            models_tested += 1

            if coreml_preds.shape != keras_preds.shape:
                print("Shape error:\nbase_params: {}\nkeras_preds.shape: {}\ncoreml_preds.shape: {}".format(
                    base_params, keras_preds.shape, coreml_preds.shape))
                shape_err_models.append(base_params)
                continue

            # Scale by the larger magnitude, floored at 1.0, so the check is
            # relative for large outputs and absolute for small ones.
            max_denominator = np.maximum(np.maximum(np.abs(coreml_preds), np.abs(keras_preds)), 1.0)
            relative_error = coreml_preds / max_denominator - keras_preds / max_denominator

            # The magnitude is compared so that negative deviations are caught
            # too; NaN compares False and is therefore reported as a failure.
            if not np.all(np.abs(relative_error) <= tolerance):
                print("Assertion error:\nbase_params: {}\nkeras_preds: {}\ncoreml_preds: {}".format(
                    base_params, keras_preds, coreml_preds))
                numerical_err_models.append(base_params)

    self.assertEqual(shape_err_models, [], msg='Shape error models {}'.format(shape_err_models))
    self.assertEqual(
        numerical_err_models,
        [],
        msg='Numerical error models {}\n'
            'Total numerical failiures: {}/{}\n'.format(
                numerical_err_models, len(numerical_err_models), models_tested),
    )
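# Scraped web-page residue, not part of the test code: "评论列表" (comment list)
# Scraped web-page residue, not part of the test code: "文章目录" (article table of contents)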