test_recurrent_stress_tests.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:coremltools 作者: apple 项目源码 文件源码
def test_SimpleGRU(self):
        """Convert a single-layer GRU Keras model and compare predictions.

        Builds a ``Sequential`` model holding one ``GRU`` layer configured
        from ``params``, fills every weight tensor with uniform random values
        in [0, 1), runs ``simple_model_eval`` on it, and asserts that every
        relative error it reports is at most 0.01.
        """
        # ``simple_model_eval`` is handed a 1-tuple of parameter dicts (the
        # original spelled this with an easy-to-miss trailing comma after
        # ``dict(...)``); all lookups below therefore go through params[0].
        params = (
            dict(
                input_dims=[1, 4, 8], go_backwards=False, activation='tanh',
                stateful=False, unroll=False, return_sequences=False,
                output_dim=4,
            ),
        )
        cfg = params[0]
        # Index 1 feeds Keras 1's ``input_length`` and index 2 its
        # ``input_dim``; index 0 is presumably the batch size — TODO confirm
        # against simple_model_eval.
        timesteps = cfg['input_dims'][1]
        features = cfg['input_dims'][2]

        model = Sequential()
        # NOTE(review): the layer is always built with unroll=True even though
        # cfg['unroll'] is False; kept as-is to preserve the tested behaviour.
        if keras.__version__.startswith('2.'):
            # Keras 2 argument names (units / recurrent_activation / input_shape).
            model.add(GRU(units=cfg['output_dim'],
                          input_shape=(timesteps, features),
                          activation=cfg['activation'],
                          recurrent_activation='sigmoid',
                          return_sequences=cfg['return_sequences'],
                          go_backwards=cfg['go_backwards'],
                          unroll=True,
                          ))
        else:
            # Keras 1 argument names (output_dim / inner_activation /
            # input_length + input_dim).
            model.add(GRU(output_dim=cfg['output_dim'],
                          input_length=timesteps,
                          input_dim=features,
                          activation=cfg['activation'],
                          inner_activation='sigmoid',
                          return_sequences=cfg['return_sequences'],
                          go_backwards=cfg['go_backwards'],
                          unroll=True,
                          ))
        # Replace the initial weights with random values of identical shapes.
        model.set_weights([np.random.rand(*w.shape) for w in model.get_weights()])
        relative_error, _keras_preds, _coreml_preds = simple_model_eval(params, model)
        for err in relative_error:
            self.assertLessEqual(err, 0.01)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号