def call(self, x, mask=None):
    """Pool ``x`` over axis 1 using learned attention weights.

    ``x`` is projected through ``self.att_W``, optionally passed through
    tanh, and scored against ``self.att_v``. The scores are softmax
    normalized and used to weight ``x`` before the optional reduction
    selected by ``self.op``.

    Args:
        x: 3-D backend tensor. Axis 1 is the axis being attended over and
            axis 2 is the one contracted with ``self.att_W``.
            (Presumably (batch, steps, features) -- confirm against caller.)
        mask: Optional tensor with one entry per (sample, step). Only read
            when ``self.op == 'attmean'``, where its sum over axis 1 is the
            divisor. ``None`` is treated as "every step is valid".

    Returns:
        The attention-weighted tensor cast to ``K.floatx()``: summed over
        axis 1 for ``'attsum'``, summed and divided by the number of
        unmasked steps for ``'attmean'``, and left un-reduced for any other
        value of ``self.op``.

    Raises:
        ValueError: if ``self.activation`` is neither falsy nor ``'tanh'``.
    """
    projected = K.dot(x, self.att_W)

    if not self.activation:
        scored = projected
    elif self.activation == 'tanh':
        scored = K.tanh(projected)
    else:
        # Previously this fell through and failed later on an unbound
        # ``weights`` variable; fail early with an explicit message instead.
        raise ValueError(
            'Unsupported attention activation: %r' % (self.activation,))

    # Contract the last axis of ``scored`` with the 1-D vector ``att_v``.
    # This is the same contraction as tensordot(att_v, scored, axes=[0, 2])
    # but uses only backend-agnostic ops, so no per-backend dispatch through
    # private module paths is needed.
    weights = K.softmax(K.sum(scored * self.att_v, axis=-1))

    # Append a trailing axis so the per-step weights broadcast across the
    # last axis of ``x`` (equivalent to repeat + permute_dimensions).
    out = x * K.expand_dims(weights)

    if self.op == 'attsum':
        out = K.sum(out, axis=1)
    elif self.op == 'attmean':
        if mask is None:
            # No mask supplied: every step counts towards the mean.
            divisor = K.cast(K.shape(x)[1], K.floatx())
        else:
            # Cast first so boolean/integer masks can be summed on any
            # backend.
            divisor = K.sum(K.cast(mask, K.floatx()), axis=1, keepdims=True)
        out = K.sum(out, axis=1) / divisor

    return K.cast(out, K.floatx())
评论列表
文章目录