def batch_average(x):
    """Reduce ``x`` to a single value: sum each row, then average the rows.

    ``x`` is first passed through ``flatten``. The result is summed along
    axis 1, giving one total per row. Those totals are then averaged with
    ``tf.reduce_mean`` (no axis given, so every remaining element is
    averaged).

    NOTE(review): this assumes ``flatten`` keeps the leading axis and
    collapses everything else to shape (batch, -1). Under that assumption
    the function means "sum over every non-batch dimension, average over
    the batch". Confirm against ``flatten``'s definition.
    """
    per_row_total = tf.reduce_sum(flatten(x), axis=1)
    return tf.reduce_mean(per_row_total)