def batch_flatten(x):
    """Reshape an n-D tensor into a 2-D tensor, keeping the first dimension.

    The result has shape ``(x.shape[0], prod(x.shape[1:]))``. All trailing
    axes are collapsed into the second axis.

    Parameters
    ----------
    x : symbolic tensor
        Input tensor with at least one dimension.

    Returns
    -------
    symbolic tensor
        2-D view of ``x`` whose first axis matches ``x.shape[0]``.
    """
    # Multiply only the trailing dimensions instead of dividing the total size
    # by x.shape[0]. The old form divided by zero when the batch was empty.
    # For every non-empty batch both forms give the same width.
    x = T.reshape(x, (x.shape[0], T.prod(x.shape[1:])))
    return x