test_recurrent_stress_tests.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:coremltools 作者: apple 项目源码 文件源码
def test_initial_state_GRU(self):
        data = np.random.rand(1, 1, 2)

        model = keras.models.Sequential()
        model.add(keras.layers.GRU(5, input_shape=(1, 2), batch_input_shape=[1, 1, 2], stateful=True))
        model.get_layer(index=1).reset_states()

        coreml_model = keras_converter.convert(model=model, input_names='data', output_names='output')
        keras_output_1 = model.predict(data)
        coreml_full_output_1 = coreml_model.predict({'data': data})
        coreml_output_1 = coreml_full_output_1['output']
        coreml_output_1 = np.expand_dims(coreml_output_1, 1)

        np.testing.assert_array_almost_equal(coreml_output_1.T, keras_output_1)

        hidden_state = (np.random.rand(1, 5))
        model.get_layer(index=1).reset_states(states=hidden_state)
        coreml_model = keras_converter.convert(model=model, input_names='data', output_names='output')
        spec = coreml_model.get_spec()
        keras_output_2 = model.predict(data)
        coreml_full_output_2 = coreml_model.predict({'data': data, spec.description.input[1].name: hidden_state[0]})
        coreml_output_2 = coreml_full_output_2['output']
        coreml_output_2 = np.expand_dims(coreml_output_2, 1)
        np.testing.assert_array_almost_equal(coreml_output_2.T, keras_output_2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号