def get_output_for(self, inputs, **kwargs):
    """Compute pairwise dot-product scores between two batched sequences.

    Parameters
    ----------
    inputs : list
        Symbolic tensors. ``inputs[0]`` is laid out as (B, N, D) and
        ``inputs[1]`` as (B, Q, D). Only these first two entries are read.
    **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    Symbolic tensor of shape (B, N, Q). Entry [b, n, q] is the dot product
    of ``inputs[0][b, n, :]`` and ``inputs[1][b, q, :]``.

    Note: this method does not apply any mask.
    """
    context = inputs[0]  # (B, N, D)
    query = inputs[1]    # (B, Q, D)
    # Move the feature axis D into the middle so it lines up with the last
    # axis of `context` for the per-batch matrix product:
    # (B, Q, D) -> (B, D, Q).
    query_by_feature = query.dimshuffle(0, 2, 1)
    # One (N x D) @ (D x Q) product per batch element -> (B, N, Q).
    pairwise_scores = T.batched_dot(context, query_by_feature)
    return pairwise_scores