def test_simple_rnn(self):
    """
    Test the conversion of a simple RNN layer.

    Builds a one-layer Keras ``Sequential`` model containing a ``SimpleRNN``
    with 32 units, converts it, and checks the converted spec's
    description (input/output names and shapes) and its first layer.
    """
    from keras.layers import SimpleRNN

    # Create a simple Keras model: 32 units over an input of
    # 10 timesteps x 32 features (Keras input_shape excludes the batch axis).
    model = Sequential()
    model.add(SimpleRNN(32, input_shape=(10, 32)))

    input_names = ['input']
    output_names = ['output']
    spec = keras.convert(model, input_names, output_names).get_spec()
    self.assertIsNotNone(spec)

    # Test the model class
    self.assertIsNotNone(spec.description)
    self.assertTrue(spec.HasField('neuralNetwork'))

    # Test the inputs and outputs.
    # The converted model exposes one more input and one more output than the
    # names supplied above; presumably the extra pair carries the recurrent
    # hidden state (TODO confirm against the converter).
    self.assertEqual(len(spec.description.input), len(input_names) + 1)
    self.assertEqual(input_names[0], spec.description.input[0].name)
    self.assertEqual(32, spec.description.input[1].type.multiArrayType.shape[0])
    self.assertEqual(len(spec.description.output), len(output_names) + 1)
    self.assertEqual(output_names[0], spec.description.output[0].name)
    self.assertEqual(32, spec.description.output[0].type.multiArrayType.shape[0])
    self.assertEqual(32, spec.description.output[1].type.multiArrayType.shape[0])

    # Test the layer parameters.
    layers = spec.neuralNetwork.layers
    # Fail with a clear assertion rather than an IndexError if no layers exist.
    self.assertGreater(len(layers), 0)
    layer_0 = layers[0]
    # Reading a protobuf sub-message field never yields None (an unset field
    # returns a default instance), so assertIsNotNone would always pass;
    # HasField verifies the simpleRecurrent field is actually set.
    self.assertTrue(layer_0.HasField('simpleRecurrent'))
    self.assertEqual(len(layer_0.input), 2)
    self.assertEqual(len(layer_0.output), 2)
评论列表
文章目录