def _distance_logits(self, x1, x2):
    """Score every position of ``x1`` against every position of ``x2``.

    Both inputs are linearly projected to ``self.project_size`` features
    (optionally with a learned, zero-initialized bias) and compared with a
    batched dot product. If ``self.scale`` is set, the scores are divided by
    ``sqrt(self.project_size)``. No normalization (e.g. softmax) is applied
    here; the raw scores are returned.

    Args:
        x1: rank-3 tensor; its last axis (axis 2) is contracted with the
            projection matrix. Assumed (batch, len1, dim1) — TODO confirm.
        x2: rank-3 tensor, assumed (batch, len2, dim2) — TODO confirm.

    Returns:
        ``matmul(proj(x1), proj(x2), transpose_b=True)``, i.e. shape
        (batch, len1, len2) under the assumptions above.

    Raises:
        ValueError: if ``self.share_project`` is set but ``x1`` and ``x2``
            do not have the same last dimension, so a single projection
            matrix cannot be applied to both.
    """
    # Capture the input widths *before* projecting: after the tensordot
    # below, x1's last axis is project_size, not its original width.
    dim1 = x1.shape.as_list()[-1]
    dim2 = x2.shape.as_list()[-1]

    # NOTE(review): presumably maps the configured self.init to a TF
    # initializer; the helper is defined outside this block.
    init = get_keras_initialization(self.init)
    project1 = tf.get_variable("project1", (dim1, self.project_size), initializer=init)
    x1 = tf.tensordot(x1, project1, [[2], [0]])

    if self.share_project:
        # project1 has input width dim1, so sharing it requires x2 to have
        # that same width (comparing against the projected x1 was wrong).
        if dim2 != dim1:
            raise ValueError(
                "share_project requires x1 and x2 to have the same last "
                "dimension, got {} and {}".format(dim1, dim2))
        project2 = project1
    else:
        project2 = tf.get_variable("project2", (dim2, self.project_size), initializer=init)
    x2 = tf.tensordot(x2, project2, [[2], [0]])

    if self.project_bias:
        # Each side gets its own bias, even when the projection is shared.
        x1 += tf.get_variable("bias1", (1, 1, self.project_size), initializer=tf.zeros_initializer())
        x2 += tf.get_variable("bias2", (1, 1, self.project_size), initializer=tf.zeros_initializer())

    dots = tf.matmul(x1, x2, transpose_b=True)
    if self.scale:
        # Scaled dot product: divide by the square root of the projected width.
        dots /= tf.sqrt(tf.cast(self.project_size, tf.float32))
    return dots
评论列表
文章目录