similarity_layers.py source code

python

Project: document-qa, author: allenai
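```python
import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable)

# get_keras_initialization is a helper defined elsewhere in the document-qa codebase.

def _distance_logits(self, x1, x2):
    """Projected dot-product similarity between two batched sequences.

    x1: [batch, n1, d1], x2: [batch, n2, d2]; returns logits of shape [batch, n1, n2].
    """
    init = get_keras_initialization(self.init)

    # Project x1 from d1 to project_size dimensions.
    dim1 = x1.shape.as_list()[-1]
    project1 = tf.get_variable("project1", (dim1, self.project_size), initializer=init)
    x1 = tf.tensordot(x1, project1, [[2], [0]])

    if self.share_project:
        # Reusing project1 requires x2 to have the same input width as the original x1.
        # Compare against dim1: x1 has already been projected to project_size here.
        if x2.shape.as_list()[-1] != dim1:
            raise ValueError("share_project requires x1 and x2 to have the same last dimension")
        project2 = project1
    else:
        project2 = tf.get_variable("project2", (x2.shape.as_list()[-1], self.project_size), initializer=init)
    x2 = tf.tensordot(x2, project2, [[2], [0]])

    if self.project_bias:
        # Learned biases, broadcast over the batch and sequence axes.
        x1 += tf.get_variable("bias1", (1, 1, self.project_size), initializer=tf.zeros_initializer())
        x2 += tf.get_variable("bias2", (1, 1, self.project_size), initializer=tf.zeros_initializer())

    # Batched dot products: [batch, n1, k] x [batch, k, n2] -> [batch, n1, n2].
    dots = tf.matmul(x1, x2, transpose_b=True)
    if self.scale:
        # Divide by sqrt(project_size) so logit magnitude does not grow with the projection width.
        dots /= tf.sqrt(tf.cast(self.project_size, tf.float32))
    return dots
```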
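To make the tensor shapes concrete, here is a minimal NumPy sketch of the same computation, assuming rank-3 inputs of shape [batch, length, dim]. The function name `projected_dot_logits` and its arguments are hypothetical and not part of document-qa; plain NumPy arrays stand in for the TensorFlow variables `project1`, `project2`, `bias1` and `bias2`.

```python
import numpy as np


def projected_dot_logits(x1, x2, project1, project2=None, bias1=None, bias2=None, scale=True):
    """Hypothetical NumPy version of the projected dot-product logits.

    x1: [batch, n1, d1], x2: [batch, n2, d2]
    project1: [d1, k]; project2: [d2, k], or None to share project1.
    Returns logits of shape [batch, n1, n2].
    """
    if project2 is None:
        # Shared projection: x2 must have the same input width as (unprojected) x1.
        if x2.shape[-1] != x1.shape[-1]:
            raise ValueError("shared projection needs equal input widths")
        project2 = project1
    k = project1.shape[1]
    p1 = np.tensordot(x1, project1, axes=([2], [0]))  # [batch, n1, k]
    p2 = np.tensordot(x2, project2, axes=([2], [0]))  # [batch, n2, k]
    if bias1 is not None:
        p1 = p1 + bias1  # bias of shape [1, 1, k] broadcasts over batch and sequence
    if bias2 is not None:
        p2 = p2 + bias2
    dots = np.matmul(p1, np.swapaxes(p2, 1, 2))  # [batch, n1, n2]
    if scale:
        dots = dots / np.sqrt(k)
    return dots


rng = np.random.default_rng(0)
x1 = rng.normal(size=(2, 3, 4))
x2 = rng.normal(size=(2, 5, 6))
w1 = rng.normal(size=(4, 8))
w2 = rng.normal(size=(6, 8))
logits = projected_dot_logits(x1, x2, w1, w2)
print(logits.shape)  # (2, 3, 5)

# Each entry is the scaled dot product of the two projected vectors.
manual = (x1[1, 2] @ w1) @ (x2[1, 4] @ w2) / np.sqrt(8)
print(np.isclose(logits[1, 2, 4], manual))  # True
```

Passing `project2=None` mimics `share_project`, and dividing by the square root of the projection width mirrors the `scale` flag, the same scaling used in scaled dot-product attention.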