def call(self, inputs, mask=None):
    """Attention-pool ``uQ`` along axis 1 into one vector per batch example.

    Computes scores ``s = batch_flatten(tanh(uQ.WQ_u + ones.(WQ_v.VQ_r)) . v)``,
    turns them into weights ``a`` with a (masked) softmax over axis 1, and
    returns the ``a``-weighted sum of ``uQ`` over axis 1.

    Args:
        inputs: list or tuple of exactly five tensors, unpacked in the order
            ``[uQ, WQ_u, WQ_v, v, VQ_r]``.
        mask: optional sequence of masks. Only ``mask[0]`` (the mask paired
            with ``uQ``) is used; it is forwarded to ``softmax``.

    Returns:
        ``K.batch_dot(uQ, a, axes=[1, 1])``.
        NOTE(review): presumably shaped (B, 2H) when uQ is (B, Q, 2H) -- confirm.

    Raises:
        ValueError: if ``inputs`` is not a list/tuple of length 5.
    """
    # Explicit check rather than ``assert`` so validation is not stripped
    # when Python runs with -O.
    if not isinstance(inputs, (list, tuple)) or len(inputs) != 5:
        raise ValueError(
            "call() expects a list of 5 inputs [uQ, WQ_u, WQ_v, v, VQ_r]"
        )
    uQ, WQ_u, WQ_v, v, VQ_r = inputs
    # Only the mask belonging to uQ matters for the softmax below.
    uQ_mask = mask[0] if mask is not None else None

    # Ones shaped like uQ reduced over axis 1 (keepdims keeps that axis at
    # size 1), so the projected (WQ_v . VQ_r) term presumably broadcasts
    # against every position of uQ . WQ_u along axis 1.
    ones = K.ones_like(K.sum(uQ, axis=1, keepdims=True))  # (B, 1, 2H)
    s_hat = K.dot(uQ, WQ_u)
    s_hat += K.dot(ones, K.dot(WQ_v, VQ_r))
    s_hat = K.tanh(s_hat)
    # Project with v, then flatten everything after the batch axis to get a
    # 2-D score tensor for the softmax over axis 1.
    s = K.batch_flatten(K.dot(s_hat, v))

    a = softmax(s, mask=uQ_mask, axis=1)
    # Contract axis 1 of uQ with axis 1 of the weights: weighted sum over it.
    rQ = K.batch_dot(uQ, a, axes=[1, 1])
    return rQ
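# 评论列表 (comment list) -- web-page label captured with this snippet, not code
# 文章目录 (table of contents) -- web-page label captured with this snippet, not code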