def flatten(x):
    """Flatten ``x`` into a single dimension.

    Builds an op via ``get_op`` from two pieces and applies it to ``x``:
    ``_flatten`` reshapes the tensor to one dimension (``reshape(-1)``), and
    ``_compute_output_shape`` reports the matching shape
    ``(total_element_count,)``.

    Args:
        x: Input tensor. Presumably a torch tensor (the original relied on
            ``Tensor.view`` with a list argument) -- TODO confirm against the
            rest of the backend.

    Returns:
        Whatever the op constructed by ``get_op`` returns for ``x``; its
        computation yields a 1-D tensor containing every element of ``x``.
    """

    def _flatten(x):
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # tensors (e.g. after a transpose/permute), while ``reshape`` returns
        # a view when possible and copies only when necessary.
        return x.reshape(-1)

    def _compute_output_shape(x):
        # The flattened length is the product of all dimensions. Cast to a
        # Python int: ``np.prod`` returns a NumPy scalar, and for an empty
        # shape (0-d input) it returns the float 1.0, which is not a valid
        # dimension size.
        return (int(np.prod(list(_get_shape(x)))),)

    return get_op(_flatten, output_shape=_compute_output_shape)(x)