layers.py source code

Language: python

Project: quora_duplicate · Author: ijinmao

The method below computes soft attention alignments between two encoded sequences: each timestep of one sequence is represented as an attention-weighted sum over the other.
# Excerpted from a layer class; assumes the class provides self.model,
# self.sequence_length and self.input_dim, and that the file imports:
#   from keras import backend as K
#   from keras.layers import Lambda
def __call__(self, x1, x2):
        def _dot_product(args):
            # Raw alignment scores between the two sequences: batch matrix
            # product x . y^T, shape (batch_size, timesteps1, timesteps2).
            x = args[0]
            y = args[1]
            return K.batch_dot(x, K.permute_dimensions(y, (0, 2, 1)))

        def _normalize(args, transpose=False):
            # Numerically stable softmax over the last axis of the attention
            # weights (subtract the row-wise max before exponentiating),
            # followed by an attention-weighted sum over x.
            att_w = args[0]
            x = args[1]
            if transpose:
                # Reuse the same score matrix for the second sequence by
                # transposing it to (batch_size, timesteps2, timesteps1).
                att_w = K.permute_dimensions(att_w, (0, 2, 1))
            e = K.exp(att_w - K.max(att_w, axis=-1, keepdims=True))
            sum_e = K.sum(e, axis=-1, keepdims=True)
            nor_e = e / sum_e
            return K.batch_dot(nor_e, x)

        # (batch_size, timesteps1, dim)
        f1 = self.model(x1)
        # (batch_size, timesteps2, dim)
        f2 = self.model(x2)
        output_shape = (self.sequence_length, self.sequence_length,)
        # attention weights, (batch_size, timesteps1, timesteps2)
        att_w = Lambda(
            _dot_product,
            output_shape=output_shape)([f1, f2])
        output_shape = (self.sequence_length, self.input_dim,)
        # (batch_size, timesteps1, dim)
        att1 = Lambda(
            _normalize, arguments={'transpose': False},
            output_shape=output_shape)([att_w, x2])
        # (batch_size, timesteps2, dim)
        att2 = Lambda(
            _normalize, arguments={'transpose': True},
            output_shape=output_shape)([att_w, x1])
        return att1, att2
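
Below is a minimal usage sketch, not the project's actual wiring: the class name SoftAlignAttention and its constructor are hypothetical, inferred only from the three attributes the method accesses (self.model, self.sequence_length, self.input_dim), and the shapes are illustrative. It assumes the imports noted above the method.

from keras.layers import Input, Dense, TimeDistributed

class SoftAlignAttention(object):  # hypothetical name and constructor
    def __init__(self, model, sequence_length, input_dim):
        self.model = model                      # shared projection applied to both inputs
        self.sequence_length = sequence_length  # timesteps per sequence
        self.input_dim = input_dim              # feature dimension per timestep

# Attach the method shown above (defined at module level in this excerpt).
SoftAlignAttention.__call__ = __call__

seq_len, dim = 30, 300
q1 = Input(shape=(seq_len, dim))  # e.g. embedded question 1
q2 = Input(shape=(seq_len, dim))  # e.g. embedded question 2
project = TimeDistributed(Dense(dim, activation='relu'))
attend = SoftAlignAttention(project, seq_len, dim)
# att1 aligns q2 to q1's timesteps; att2 aligns q1 to q2's timesteps.
att1, att2 = attend(q1, q2)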