def __call__(self, x, hs):
    """Compute an attention-weighted context vector for each batch row.

    For every time step of the sequences in ``hs`` (obtained via
    ``F.transpose_sequence``), the step's hidden states are scored against
    ``x`` with ``self._score_func`` and exponentiated. Two running sums are
    kept: the exp-score-weighted sum of hidden states, and the sum of the
    exp-scores. Their ratio is a softmax-weighted average of each row's
    hidden states over that row's own (real, unpadded) time steps.

    Args:
        x: Query of shape ``(batch, dim)``.
        hs: Sequence of per-example hidden-state arrays; only the first
            ``batch`` entries are used. Presumably each is ``(len_i, dim)``
            and they are sorted by decreasing length, as
            ``F.transpose_sequence`` requires -- TODO confirm with callers.

    Returns:
        Variable of shape ``(batch, dim)`` holding the context vectors.
        Rows with no time steps at all get a zero vector.
    """
    batch, dim = x.shape
    weighted_sum = 0  # running sum over t of h_t * exp(score_t)
    normalizer = 0    # running sum over t of exp(score_t)
    for h in F.transpose_sequence(hs[:batch]):
        size = h.shape[0]
        if size < batch:
            # Sequences that already ended are absent from this step; pad
            # with zero rows so every step has exactly ``batch`` rows.
            h = F.vstack([h, variable.Variable(
                self.xp.zeros((batch - size, h.shape[1]), dtype=h.dtype))])
        score = self._score_func(x, h)
        e = F.exp(score)
        if size < batch:
            # Padded rows are not real time steps, but exp(score) is still
            # strictly positive for them. Left unmasked, they would inflate
            # the normalizer and shorter sequences' weights would not sum
            # to 1. The leading axis of ``e`` is the batch axis (required
            # by the batch_matmul below), so zero rows ``size:`` there.
            mask = self.xp.zeros(e.shape, dtype=e.dtype)
            mask[:size] = 1
            e = e * mask
        normalizer += e
        weighted_sum += batch_matmul(h, e)
    # A row with no real time steps has a zero normalizer (and a zero
    # weighted sum); divide by 1 there so its context is zero, not NaN.
    empty_rows = self.xp.asarray(normalizer.data == 0, dtype=normalizer.dtype)
    normalizer = normalizer + empty_rows
    c = F.reshape(batch_matmul(F.reshape(weighted_sum, (batch, dim)),
                               (1 / normalizer)), (batch, dim))
    return c
# (Scraped web-page labels "Comment list" / "Table of contents" -- not part of the code.)