def get_output_for(self, input, **kwargs):
    """Return ``nonlinearity(quadratic + linear + bias)`` for ``input``.

    Three terms are summed, in this order, before the nonlinearity is
    applied:

    * quadratic: ``T.tensordot(input, self.V, axes=[1, 2])`` contracts
      axis 1 of ``input`` with axis 2 of ``self.V``. The result is then
      combined with ``input`` again through ``T.batched_dot``.
    * linear: ``T.dot(input, self.W)``.
    * bias: ``self.b`` with a leading broadcastable axis inserted via
      ``dimshuffle('x', 0)``. This requires ``self.b`` to be 1-D.

    NOTE(review): this appears to assume ``input`` is (batch, num_inputs),
    ``self.V`` is (num_units, num_inputs, num_inputs), and ``self.W`` is
    (num_inputs, num_units), giving a (batch, num_units) output. That would
    make each unit compute ``x^T V_k x + x . W_k + b_k``. TODO confirm these
    shapes against the layer's ``__init__``.

    Parameters
    ----------
    input : symbolic tensor
        Layer input. The name is kept for interface compatibility, even
        though it shadows the builtin.
    **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    symbolic tensor
        ``self.nonlinearity`` applied to the summed pre-activation.
    """
    # Contract the feature axis of the input against the last axis of V.
    # Presumably this yields (batch, num_units, num_inputs) -- TODO confirm.
    projected = T.tensordot(input, self.V, axes=[1, 2])
    # Per-sample contraction of that projection with the input again,
    # which forms the second-order (quadratic) term.
    quadratic = T.batched_dot(projected, input)
    linear = T.dot(input, self.W)
    # Prepend a broadcastable batch axis so the bias adds row-wise.
    bias = self.b.dimshuffle('x', 0)
    pre_activation = quadratic + linear + bias
    return self.nonlinearity(pre_activation)