def log_sum_exp(x, axis=1):
    """Return ``log(sum(exp(x), axis=axis))`` computed in a numerically stable way.

    The per-slice maximum is subtracted before exponentiating so large inputs
    do not overflow, then added back after the log.

    Args:
        x: input tensor (assumes ``T`` is a Theano-style tensor module such as
            ``theano.tensor`` -- TODO confirm against the file's imports).
        axis: axis (or axes / ``None``) to reduce over. Defaults to 1, matching
            the original 2-D row-wise behaviour.

    Returns:
        Tensor with ``axis`` reduced away.

    Note:
        A slice consisting entirely of ``-inf`` produces NaN, because
        ``-inf - (-inf)`` is NaN.
    """
    # keepdims=True keeps the reduced axis with size 1, so the shift broadcasts
    # against x for any rank/axis. The previous dimshuffle(0, 'x') was only
    # correct for 2-D input with axis=1.
    m_keep = T.max(x, axis=axis, keepdims=True)
    shifted_sum = T.sum(T.exp(x - m_keep), axis=axis)
    # Add back the max in its reduced shape so it lines up with shifted_sum.
    return T.max(x, axis=axis) + T.log(shifted_sum)