def get_output_for(self, inputs, attention_only=False, **kwargs):
    """Gate document representations with query summaries pooled by attention.

    For every document position, the scores in ``inputs[2]`` are softmax-
    normalised over the query axis, multiplied by ``self.mask`` and
    renormalised. The resulting weights pool the query representations,
    and the pooled vectors are combined with the document representations
    by ``self.gating_fn``.

    Parameters
    ----------
    inputs : list
        ``inputs[0]``: document representations, B x N x D.
        ``inputs[1]``: query representations, B x Q x D.
        ``inputs[2]``: attention scores, B x N x Q, or B x Q x N when
        ``self.transpose`` is true.
    attention_only : bool
        Accepted for interface compatibility; not used by this method.
    **kwargs
        Ignored.

    Returns
    -------
    ``gating_fn(inputs[0], q_rep)``, where ``q_rep`` (B x N x D) is the
    attention-weighted query summary for each document position.

    Notes
    -----
    ``self.mask`` is B x Q. It is presumably 1 for valid query positions
    and 0 for padding (TODO confirm against the caller). A document
    position whose masked weights sum to 0 gets all-zero weights, and so a
    zero ``q_rep`` row, instead of NaN.

    ``self.gating_fn`` may be a callable or a string. Strings are passed
    to ``eval`` and must therefore come only from trusted configuration.
    """
    # Bring the scores into B x N x Q layout regardless of input orientation.
    if self.transpose:
        M = inputs[2].dimshuffle((0, 2, 1))
    else:
        M = inputs[2]
    batch_size, doc_len, query_len = M.shape[0], M.shape[1], M.shape[2]

    # T.nnet.softmax operates on matrices: flatten the batch and document
    # axes, normalise over the query axis, then restore the 3-D shape.
    alphas = T.nnet.softmax(T.reshape(M, (batch_size * doc_len, query_len)))
    # Zero out the weights on masked query positions.
    alphas_r = (T.reshape(alphas, (batch_size, doc_len, query_len))
                * self.mask[:, np.newaxis, :])  # B x N x Q

    # Renormalise over the remaining query positions. When a row has no
    # valid position (all-zero mask), or all remaining weights underflowed
    # to 0, the sum is 0 and plain division would produce NaN. Dividing by
    # 1 instead leaves that row all zeros; all other rows are unchanged.
    denom = alphas_r.sum(axis=2)[:, :, np.newaxis]  # B x N x 1
    denom = T.switch(T.eq(denom, 0), 1, denom)
    alphas_r = alphas_r / denom  # B x N x Q

    # Attention-weighted summary of the query for each document position.
    q_rep = T.batched_dot(alphas_r, inputs[1])  # B x N x D

    gating_fn = self.gating_fn
    if not callable(gating_fn):
        # NOTE(review): string gating functions are eval'd, as before; keep
        # this value restricted to trusted configuration.
        gating_fn = eval(gating_fn)
    return gating_fn(inputs[0], q_rep)
评论列表
文章目录