softattention.py source code

python

Project: aes  Author: feidong1991
from keras import backend as K  # assumed: K is the Keras backend module

def call(self, x, mask=None):
    # x: (batch, timesteps, features). Project every timestep with att_W.
    y = K.dot(x, self.att_W)
    if self.activation == 'tanh':
        y = K.tanh(y)
    elif self.activation:
        raise ValueError('Unsupported activation: %s' % self.activation)
    # Contract att_v with the feature axis of y: one score per timestep.
    if K.backend() == 'theano':
        weights = K.theano.tensor.tensordot(self.att_v, y, axes=[0, 2])
    elif K.backend() == 'tensorflow':
        weights = K.tensorflow.python.ops.math_ops.tensordot(self.att_v, y, axes=[0, 2])
    else:
        raise ValueError('Unsupported backend: %s' % K.backend())
    # Normalize the scores over timesteps.
    weights = K.softmax(weights)
    # Repeat the weights along the feature axis and scale x.
    out = x * K.permute_dimensions(K.repeat(weights, x.shape[2]), [0, 2, 1])
    if self.op == 'attsum':
        out = K.sum(out, axis=1)
    elif self.op == 'attmean':
        # Needs a mask: divide by the number of unmasked timesteps.
        out = K.sum(out, axis=1) / K.sum(K.cast(mask, K.floatx()), axis=1, keepdims=True)
    return K.cast(out, K.floatx())
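
To make the tensor shapes in call() easier to follow, here is a minimal NumPy sketch of the same computation. It is not the project's code. The function name soft_attention, its argument names, and the assumed shapes (x as (batch, timesteps, features), att_W as (features, features), att_v as (features,)) are illustrative assumptions. The sketch also subtracts the row maximum before exponentiating, a common numerical-stability step that the original snippet does not show.

```python
import numpy as np

def soft_attention(x, att_W, att_v, activation="tanh", op="attsum", mask=None):
    """Hypothetical NumPy re-statement of the attention pooling in call()."""
    y = x @ att_W                                  # (batch, timesteps, features)
    if activation == "tanh":
        y = np.tanh(y)
    # Contract att_v with the feature axis of y: one score per timestep.
    scores = np.tensordot(att_v, y, axes=[0, 2])   # (batch, timesteps)
    # Softmax over timesteps (max subtracted for numerical stability).
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weight every timestep, then pool over the time axis.
    out = (x * weights[:, :, None]).sum(axis=1)    # (batch, features)
    if op == "attmean":
        if mask is None:
            raise ValueError("op='attmean' needs a mask")
        out = out / mask.sum(axis=1, keepdims=True)
    return out, weights

# Toy inputs, chosen only for illustration.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
att_W = rng.normal(size=(4, 4))
att_v = rng.normal(size=4)
out, weights = soft_attention(x, att_W, att_v)
print(out.shape, weights.shape)
```

With att_v set to all zeros every timestep gets the same score, so the weights become uniform and the 'attsum' output reduces to the plain mean of x over timesteps, which is a quick sanity check on the pooling step.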