def apply(self, is_train, tensor1: tf.Tensor, tensor2: tf.Tensor) -> tf.Tensor:
    """Merge two tensors via a learned projection-and-product term.

    ``tensor1`` is mapped through a trainable matrix ``w1`` of shape
    ``(d1, d2)``, where ``d1``/``d2`` are the static last-dimension sizes of
    ``tensor1``/``tensor2``. When ``self.scale`` is set, the projection is
    divided by ``sqrt(d1)``. The result is then multiplied elementwise with
    ``tensor2``. The output concatenates ``[tensor1, product]``, plus the raw
    ``tensor2`` when ``self.include_unscaled`` is set, along ``tensor1``'s
    last axis.

    Args:
        is_train: Not read by this method; presumably part of a shared
            layer interface (TODO confirm against the base class).
        tensor1: Tensor whose last dimension must be statically known.
        tensor2: Tensor whose last dimension must be statically known. It
            must be elementwise-multipliable with the projection of
            ``tensor1`` (assumed same leading shape — TODO confirm with
            callers).

    Returns:
        The concatenation described above, along the last axis of ``tensor1``.
    """
    initializer = get_keras_initialization(self.init)

    # Static last-dimension sizes; these fix the shape of the weight matrix.
    in_dim = tensor1.shape.as_list()[-1]
    out_dim = tensor2.shape.as_list()[-1]
    w1 = tf.get_variable("w1", (in_dim, out_dim), initializer=initializer)

    last_axis = len(tensor1.shape) - 1

    # Contract tensor1's last axis with w1's first axis: (..., d1) -> (..., d2).
    projected = tf.tensordot(tensor1, w1, [[last_axis], [0]])
    if self.scale:
        # Divide by sqrt(d1); presumably to keep the projection's magnitude
        # from growing with the input width (confirm intent).
        projected = projected / np.sqrt(in_dim)
    projected = projected * tensor2

    extra = [tensor2] if self.include_unscaled else []
    return tf.concat([tensor1, projected] + extra, axis=last_axis)
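# NOTE(review): leftover web-page chrome from the source this was copied from:
# "评论列表" (Comment list), "文章目录" (Table of contents).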