test_recurrent_stress_tests.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:coremltools 作者: apple 项目源码 文件源码
def test_SimpleRNN(self):
        """Check that a single-layer SimpleRNN model matches Core ML within 1%.

        Builds a ``Sequential`` model holding one ``SimpleRNN`` layer configured
        from ``params``, using whichever constructor signature matches the
        installed Keras version. It then asserts that every entry of the
        relative error reported by ``simple_model_eval`` is <= 0.01.
        """
        # Kept as a 1-tuple holding the config dict: this is the shape this test
        # passes to simple_model_eval (presumably it indexes params[0] -- confirm
        # before changing). The explicit parentheses make the tuple visible
        # instead of relying on a trailing comma.
        params = (
            dict(
                input_dims=[1, 2, 100], go_backwards=False, activation='tanh',
                stateful=False, unroll=False, return_sequences=True,
                output_dim=4,  # Passes for < 3
            ),
        )
        config = params[0]
        input_dims = config['input_dims']

        # Keyword arguments passed identically in both Keras branches below.
        common_kwargs = dict(
            activation=config['activation'],
            return_sequences=config['return_sequences'],
            go_backwards=config['go_backwards'],
            # NOTE(review): unroll is forced on here; config['unroll'] (False)
            # is not used when building the layer -- confirm this is intended.
            unroll=True,
        )

        model = Sequential()
        if keras.__version__[:2] == '2.':
            # Keras 2.x signature: ``units`` plus a combined ``input_shape``.
            layer = SimpleRNN(units=config['output_dim'],
                              input_shape=(input_dims[1], input_dims[2]),
                              **common_kwargs)
        else:
            # Non-2.x Keras (presumably 1.x): ``output_dim`` with separate
            # ``input_length`` / ``input_dim`` arguments.
            layer = SimpleRNN(output_dim=config['output_dim'],
                              input_length=input_dims[1],
                              input_dim=input_dims[2],
                              **common_kwargs)
        model.add(layer)

        # Only the relative errors are checked; the raw predictions are unused.
        relative_error, _, _ = simple_model_eval(params, model)
        for err in relative_error:
            self.assertLessEqual(err, 0.01)
评论列表


问题


面经


文章

微信
公众号

扫码关注公众号