def call(self, inputs):
    """Return the batch dot product of the two tensors in ``inputs``.

    ``self.axes`` is either a single int, used for both tensors, or an
    indexable sequence giving one axis per tensor. Negative axes are
    converted to non-negative ones using the rank (``K.ndim``) of the
    corresponding tensor. When ``self.normalize`` is truthy, each tensor
    is L2-normalized along its own axis before the product is taken.

    Args:
        inputs: Indexable holding the two operands at positions 0 and 1.

    Returns:
        The value produced by ``K.batch_dot`` for the (optionally
        normalized) operands and the resolved axes.
    """
    # NOTE(review): ``K`` is presumably the Keras backend module imported
    # elsewhere in this file -- confirm.
    left, right = inputs[0], inputs[1]
    spec = self.axes

    if isinstance(spec, int):
        # One axis shared by both operands; a negative value is resolved
        # separately against each operand's rank.
        if spec >= 0:
            resolved = [spec, spec]
        else:
            resolved = [spec % K.ndim(left), spec % K.ndim(right)]
    else:
        # One axis per operand; non-negative entries pass through as-is.
        resolved = [
            axis % K.ndim(inputs[pos]) if axis < 0 else axis
            for pos, axis in enumerate(spec)
        ]

    if self.normalize:
        left = K.l2_normalize(left, axis=resolved[0])
        right = K.l2_normalize(right, axis=resolved[1])

    return K.batch_dot(left, right, resolved)
评论列表
文章目录