def get_output_for(self, inputs, **kwargs):
    """Score the document against one query vector and pool the aggregator.

    Shapes are carried over from the original annotations; nothing in
    this method checks them:

        inputs[0]        B x N x D   document
        inputs[1]        B x Q x D   query
        self.aggregator  B x N x C
        self.pointer     B x 1
        self.mask        B x N

    NOTE(review): the advanced index below only produces a B x D result
    if self.pointer behaves as a length-B integer vector -- confirm
    against the caller.

    Returns the batched product of the masked, renormalised weights with
    self.aggregator (B x C under the shapes listed above). ``kwargs`` is
    accepted but not used here.
    """
    doc = inputs[0]
    query = inputs[1]

    # For every batch row, take the query position named by self.pointer.
    batch_idx = T.arange(query.shape[0])
    picked = query[batch_idx, self.pointer, :]  # B x D

    # Dot product of each document position with the picked query vector.
    scores = T.batched_dot(doc, picked)  # B x N

    # Softmax over positions, multiply by the mask (presumably 0/1), then
    # divide each row by its own sum so the remaining weights sum to one.
    # A row whose masked weights sum to zero is divided by zero here.
    weights = T.nnet.softmax(scores) * self.mask  # B x N
    row_totals = weights.sum(axis=1)[:, np.newaxis]  # B x 1
    weights = weights / row_totals  # B x N

    return T.batched_dot(weights, self.aggregator)
评论列表
文章目录