def flatten(output):
    """Reshape a tensor to 2-D, keeping the leading dimension.

    All dimensions after the first are collapsed into one, so an input of
    dynamic shape ``(d0, d1, ..., dk)`` becomes ``(d0, d1 * ... * dk)``.
    A rank-1 input of shape ``(d0,)`` becomes ``(d0, 1)``.

    Args:
        output: A tensor of rank >= 1. The first axis is treated as the
            batch axis. The rank does not need to be known statically.

    Returns:
        A tensor with the same elements and dtype as ``output``, shaped
        ``[batch, flat_size]``.
    """
    # Use the dynamic shape so this works when some dimensions (typically
    # the batch size) are only known at run time.
    dyn_shape = tf.shape(output)
    batch = dyn_shape[0]
    # reduce_prod of an empty slice is 1, which reproduces the original
    # loop's starting value for rank-1 inputs. Unlike a -1 placeholder in
    # reshape, it also stays well-defined when the batch is empty.
    flat_size = tf.reduce_prod(dyn_shape[1:])
    # tf.stack replaces tf.pack, which was removed in TensorFlow 1.0.
    return tf.reshape(output, tf.stack([batch, flat_size]))