def build(self, input_shape):
    """Create the layer's trainable weights for inputs of ``input_shape``.

    Creates two weights via ``self.add_weight``:

    * ``kernel`` with shape ``(input_dim, self.units)``, where ``input_dim``
      is the last dimension of ``input_shape``;
    * ``k`` with shape ``(1,)`` when ``self.tied_k`` is truthy, otherwise
      ``(self.units,)``.

    It then sets ``self.input_spec`` so the input must have at least two
    dimensions and a last axis of size ``input_dim``, and marks the layer as
    built.

    Args:
        input_shape: Shape of the layer input. Must have at least two
            dimensions and a defined (non-``None``) last dimension.

    Raises:
        ValueError: If ``input_shape`` has fewer than two dimensions, or if
            its last dimension is ``None``.
    """
    # Explicit checks rather than ``assert``: asserts are removed under
    # ``python -O``, after which bad shapes would fail later and opaquely.
    if len(input_shape) < 2:
        raise ValueError(
            'Expected an input with at least 2 dimensions, '
            'got input_shape=%s' % (input_shape,))
    input_dim = input_shape[-1]
    if input_dim is None:
        # The kernel's first axis is input_dim, so it must be known now.
        raise ValueError(
            'The last dimension of the input must be defined, '
            'got input_shape=%s' % (input_shape,))

    self.kernel = self.add_weight(shape=(input_dim, self.units),
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)

    # k is a single value when tied_k is set, else a vector of length units
    # (presumably one value shared by all units vs. one per unit -- confirm
    # against call()).
    k_size = (1,) if self.tied_k else (self.units,)
    self.k = self.add_weight(shape=k_size,
                             initializer=self.k_initializer,
                             name='k',
                             regularizer=self.k_regularizer,
                             constraint=self.k_constraint)

    # Pin the last input axis to the size the kernel was built for.
    self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
    self.built = True
评论列表
文章目录