def build(self, input_shape):
    """Build the recurrent and dense sub-layers, then set up the dense state.

    Each sub-layer is built only if its ``built`` flag is not already set.
    Afterwards the parent ``build`` is invoked and the dense state spec and
    dense state attributes are (re)initialised.

    Args:
        input_shape: Shape handed to the recurrent layer's ``build`` and
            ``compute_output_shape``. Its first element is used as the
            batch size of the dense state spec when ``self.stateful`` is
            truthy.
    """
    rnn = self.recurrent_layer
    if not rnn.built:
        rnn.build(input_shape)
    rnn_out_shape = rnn.compute_output_shape(input_shape)

    dense = self.dense_layer
    if not dense.built:
        if self.return_sequences:
            # Keep only elements 0 and 2 of the recurrent output shape --
            # presumably (batch, time, features) -> (batch, features), i.e.
            # the dense layer sees one timestep at a time. TODO confirm.
            dense_in_shape = (rnn_out_shape[0], rnn_out_shape[2])
        else:
            dense_in_shape = rnn_out_shape
        dense.build(dense_in_shape)

    # NOTE(review): super() is anchored on RNNCell; the enclosing class
    # header is not visible here -- confirm this is the intended MRO start.
    super(RNNCell, self).build(input_shape)

    # The state spec carries a concrete batch size only when stateful;
    # otherwise the batch dimension is left as None.
    if self.stateful:
        batch_size = input_shape[0]
    else:
        batch_size = None
    self.dense_state_spec = InputSpec(shape=(batch_size, dense.units))
    self.dense_state = None
评论列表
文章目录