test_recurrent_stress_tests.py source code

python

Project: coremltools  Author: apple
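from keras.models import Sequential
from keras.layers import LSTM

# Method of a unittest.TestCase subclass. simple_model_eval is a helper
# defined or imported elsewhere in the same test module (not shown here).
def test_SimpleLSTMStacked(self):
    # The trailing comma wraps the dict in a 1-tuple, which is why the code
    # indexes params[0] and passes the whole tuple to simple_model_eval.
    params = dict(
        input_dims=[1, 1, 1], go_backwards=False, activation='tanh',
        stateful=False, unroll=False, return_sequences=False, output_dim=1
    ),
    model = Sequential()
    # First LSTM (Keras 1.x-style arguments): input_dims[1] is the sequence
    # length and input_dims[2] the feature size. return_sequences=True makes it
    # emit one hidden vector per timestep, which the second LSTM consumes.
    model.add(LSTM(output_dim=params[0]['output_dim'],
                   input_length=params[0]['input_dims'][1],
                   input_dim=params[0]['input_dims'][2],
                   activation=params[0]['activation'],
                   inner_activation='sigmoid',
                   return_sequences=True,
                   go_backwards=params[0]['go_backwards'],
                   unroll=params[0]['unroll'],
                   ))
    # Second LSTM returns only its final output.
    model.add(LSTM(output_dim=1,
                   activation='tanh',
                   inner_activation='sigmoid',
                   ))
    # Compare Keras and Core ML predictions; every relative error must be <= 0.01.
    relative_error, keras_preds, coreml_preds = simple_model_eval(params, model)
    for err in relative_error:
        self.assertLessEqual(err, 0.01)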
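To see why the first layer needs return_sequences=True, here is a minimal NumPy sketch of a two-layer LSTM forward pass. It is not the Keras or Core ML implementation. The helper names (lstm_layer, random_params), the packing of all four gates into one weight matrix with gate order i, f, c, o, and the layer sizes are assumptions made for illustration. It follows the test's settings: sigmoid for inner_activation (gates), tanh for activation, and a zero initial state because stateful=False.

```python
import numpy as np

rng = np.random.default_rng(0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_layer(x_seq, W, U, b, return_sequences):
    """Run one LSTM layer over x_seq of shape (timesteps, input_dim).

    Assumed layout (hypothetical, for this sketch only):
    W: (input_dim, 4 * units), U: (units, 4 * units), b: (4 * units,),
    with gates packed in the order i, f, c, o.
    """
    units = U.shape[0]
    h = np.zeros(units)  # zero initial state (non-stateful)
    c = np.zeros(units)
    outputs = []
    for x_t in x_seq:
        z = x_t @ W + h @ U + b
        i = sigmoid(z[:units])                    # input gate (inner_activation)
        f = sigmoid(z[units:2 * units])           # forget gate (inner_activation)
        c_hat = np.tanh(z[2 * units:3 * units])   # candidate cell (activation)
        o = sigmoid(z[3 * units:])                # output gate (inner_activation)
        c = f * c + i * c_hat
        h = o * np.tanh(c)
        outputs.append(h)
    # Full sequence (timesteps, units) or just the last hidden state (units,).
    return np.stack(outputs) if return_sequences else h


def random_params(in_dim, units):
    # Hypothetical random weights; a real comparison would copy trained weights.
    return (rng.standard_normal((in_dim, 4 * units)),
            rng.standard_normal((units, 4 * units)),
            rng.standard_normal(4 * units))


timesteps, input_dim, units1, units2 = 3, 2, 4, 1
x = rng.standard_normal((timesteps, input_dim))
p1 = random_params(input_dim, units1)
p2 = random_params(units1, units2)

# Layer 1 returns every timestep so layer 2 receives a sequence.
seq = lstm_layer(x, *p1, return_sequences=True)
# Layer 2 returns only its final hidden state.
out = lstm_layer(seq, *p2, return_sequences=False)
print(seq.shape, out.shape)  # (3, 4) (1,)
```

If the first layer returned only its final state, the second layer would receive a single vector rather than a sequence. That is why the stacked test sets return_sequences=True on the first LSTM and leaves the second at its default.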