test_keras2.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:coremltools 作者: apple 项目源码 文件源码
def test_bidir(self):
        """
        Test the conversion of a Keras ``Bidirectional(LSTM)`` layer.

        Builds a single-layer ``Sequential`` model that wraps a 32-unit LSTM
        in ``Bidirectional`` (input shape ``(10, 32)``), converts it, and
        checks that:

        * the spec contains a description and a ``neuralNetwork`` model;
        * the spec has the named input followed by 4 extra inputs, each with
          a first dimension of 32;
        * the spec has the named output with a first dimension of 64
          (2 x 32, i.e. both directions' outputs, presumably concatenated)
          followed by 4 extra outputs, each with a first dimension of 32;
        * the first layer is a bidirectional LSTM with 5 inputs and 5 outputs.
        """
        from keras.layers import LSTM
        from keras.layers.wrappers import Bidirectional

        hidden_size = 32
        # Extra state tensors the converter exposes beyond the named
        # input/output (presumably h and c for each of the two directions).
        num_states = 4

        # Create a simple Keras model
        model = Sequential()
        model.add(Bidirectional(LSTM(hidden_size, input_shape=(10, 32)),
                                input_shape=(10, 32)))

        input_names = ['input']
        output_names = ['output']
        spec = keras.convert(model, input_names, output_names).get_spec()
        self.assertIsNotNone(spec)

        # Test the model class. Protobuf sub-messages are never None, so use
        # HasField to check that they were actually populated.
        self.assertTrue(spec.HasField('description'))
        self.assertTrue(spec.HasField('neuralNetwork'))

        # Test the inputs: the named input first, then the state inputs.
        inputs = spec.description.input
        self.assertEqual(len(inputs), len(input_names) + num_states)
        self.assertEqual(input_names[0], inputs[0].name)
        for i in range(1, num_states + 1):
            self.assertEqual(hidden_size,
                             inputs[i].type.multiArrayType.shape[0])

        # Test the outputs: the named output first (both directions, so twice
        # the hidden size), then the state outputs.
        outputs = spec.description.output
        self.assertEqual(len(outputs), len(output_names) + num_states)
        self.assertEqual(output_names[0], outputs[0].name)
        self.assertEqual(2 * hidden_size,
                         outputs[0].type.multiArrayType.shape[0])
        for i in range(1, num_states + 1):
            self.assertEqual(hidden_size,
                             outputs[i].type.multiArrayType.shape[0])

        # Test the layer parameters.
        layers = spec.neuralNetwork.layers
        layer_0 = layers[0]
        # Reading an unset protobuf sub-message returns a default instance,
        # never None, so assertIsNotNone could not fail here; HasField checks
        # that the layer really is a bidirectional LSTM.
        self.assertTrue(layer_0.HasField('biDirectionalLSTM'))
        # One data tensor plus the state tensors on each side of the layer.
        self.assertEqual(len(layer_0.input), 1 + num_states)
        self.assertEqual(len(layer_0.output), 1 + num_states)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号