def build(self, input_shape):
    """Validate the two input shapes and create the match weight ``self.M``.

    Args:
        input_shape: A list of exactly two shape tuples. Index 0 of both
            shapes must be equal. Index 2 of each shape sets the last two
            dimensions of ``M``. (Presumably ``(batch, seq_len, emb_dim)``
            -- TODO confirm against callers.)

    Creates ``self.M`` via ``self.add_weight`` with shape
    ``(self.channel, shape1[2], shape2[2])``, trainable:

    * If ``self.init_diag`` is truthy, every channel is initialised to ones
      on the diagonal and uniform noise in ``[-0.05, 0.05)`` elsewhere.
    * Otherwise the ``'uniform'`` initializer is used.

    Raises:
        ValueError: if ``input_shape`` is not a list of two shapes, if the
            index-0 dimensions differ, or if ``init_diag`` is set and the
            index-2 dimensions differ.
    """
    if not isinstance(input_shape, list) or len(input_shape) != 2:
        raise ValueError('A `MatchTensor` layer should be called '
                         'on a list of 2 inputs.')
    shape1, shape2 = input_shape
    if shape1[0] != shape2[0]:
        raise ValueError(
            'Dimension incompatibility '
            '%s != %s. ' % (shape1[0], shape2[0]) +
            'Layer shapes: %s, %s' % (shape1, shape2))

    weight_shape = (self.channel, shape1[2], shape2[2])
    if self.init_diag:
        # A diagonal only makes sense for square per-channel matrices.
        if shape1[2] != shape2[2]:
            raise ValueError( 'Use init_diag need same embedding shape.' )
        m_diag = np.float32(np.random.uniform(-0.05, 0.05, weight_shape))
        diag = np.arange(shape1[2])
        # Set M[c, j, j] = 1 for every channel c and index j in one step.
        m_diag[:, diag, diag] = 1.0

        # add_weight resolves `initializer` as a string, Initializer or
        # callable; a bare ndarray is rejected ("Could not interpret
        # initializer identifier"), so expose the precomputed values
        # through a callable with the (shape, dtype=None) signature.
        def initializer(shape, dtype=None):
            return m_diag
    else:
        initializer = 'uniform'

    self.M = self.add_weight(name='M',
                             shape=weight_shape,
                             initializer=initializer,
                             trainable=True)
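# 评论列表 ("comment list" -- blog-page navigation residue, not code)
# 文章目录 ("table of contents" -- blog-page navigation residue, not code)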