def build(self, input_shape):
    """Create the decoder's weights for a layer input of shape ``input_shape``.

    The parent class is built on ``input_shape`` with ``self.output_length``
    inserted as axis 1 (presumably the time axis of a recurrent parent --
    TODO confirm). During that call ``self.output_dim`` is temporarily set to
    ``self.hidden_dim`` and ``self.initial_weights`` to ``None``. The parent
    therefore sizes its weights by ``hidden_dim`` and is handed no initial
    weights (presumably because the saved weights also cover ``W_y``/``b_y``,
    which do not exist yet). Both attributes are restored afterwards, even if
    the parent build raises.

    After this method returns:

    - ``self.hidden_dim`` has been set to the last dimension of the expanded
      shape if it was falsy.
    - ``self.output_dim`` equals the last dimension of the expanded shape.
      This overrides any previously configured value.
    - An output projection ``W_y`` of shape ``(hidden_dim, output_dim)`` and
      a bias ``b_y`` of shape ``(output_dim,)`` have been appended to
      ``self.trainable_weights``.
    - Saved initial weights, if any, have been loaded with
      ``self.set_weights`` and the ``initial_weights`` attribute deleted.
    - ``self.input_spec`` describes the shape actually received, without the
      inserted axis.

    Args:
        input_shape: Shape of the layer input as a sequence of dimensions;
            presumably ``(batch, features)`` -- TODO confirm against callers.
    """
    shape = list(input_shape)
    # Shape the parent is built with: output_length inserted as axis 1.
    # Kept as a separate list so the one handed to the parent is never
    # mutated afterwards.
    seq_shape = shape[:1] + [self.output_length] + shape[1:]

    if not self.hidden_dim:
        self.hidden_dim = seq_shape[-1]
    output_dim = seq_shape[-1]

    # Build the parent with output_dim == hidden_dim and without initial
    # weights. Restore both in `finally` so a failing parent build does not
    # leave the layer in this temporary configuration or lose the weights.
    initial_weights = self.initial_weights
    self.output_dim = self.hidden_dim
    self.initial_weights = None
    try:
        super(LSTMDecoder, self).build(seq_shape)
    finally:
        self.output_dim = output_dim
        self.initial_weights = initial_weights

    # Projection from the hidden size to the output size.
    self.W_y = self.init((self.hidden_dim, self.output_dim),
                         name='{}_W_y'.format(self.name))
    # Explicit 1-tuple: a bare `(self.output_dim)` is just an int, not a shape.
    self.b_y = K.zeros((self.output_dim,), name='{}_b_y'.format(self.name))
    self.trainable_weights += [self.W_y, self.b_y]

    # Apply saved weights only now that W_y/b_y exist alongside the parent's.
    if self.initial_weights is not None:
        self.set_weights(self.initial_weights)
        del self.initial_weights

    # The layer itself accepts the shape it was given, without the
    # inserted output_length axis.
    self.input_spec = [InputSpec(shape=tuple(shape))]
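# (Web-page residue from the source this was copied from: "Comment list")
# (Web-page residue from the source this was copied from: "Table of contents")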