def reshape(x, shape):
    """Reshape ``x`` to ``shape`` by calling ``x.view(shape)``.

    The reshape is wrapped with ``get_op`` (defined elsewhere in this
    module), together with a function that infers the static output shape.
    The callable ``get_op`` returns is then applied to ``x``.

    Args:
        x: Input tensor. It presumably exposes ``.view`` (e.g. a
            torch.Tensor) — TODO confirm against ``get_op``.
        shape: Target shape as a sequence of ints. At most one entry may be
            ``-1``, meaning "infer this dimension from the element count".

    Returns:
        Whatever the ``get_op(...)`` callable returns when applied to ``x``.
    """
    def _reshape(x, shape=shape):
        # The actual reshape is delegated to the tensor's own ``view``.
        return x.view(shape)

    def _compute_output_shape(x, shape=shape):
        # Always return a tuple so both branches give callers the same type.
        if -1 not in shape:
            return tuple(shape)
        # np.prod of an empty list is the float 1.0. Force an integer product
        # so that shapes like (-1,), or a scalar input, yield int dimensions
        # rather than floats.
        n_elems = int(np.prod(list(_get_shape(x)), dtype=np.int64))
        known_dims = list(shape)
        known_dims.remove(-1)
        known_elems = int(np.prod(known_dims, dtype=np.int64))
        out_shape = list(shape)
        out_shape[out_shape.index(-1)] = n_elems // known_elems
        return tuple(out_shape)

    return get_op(_reshape, output_shape=_compute_output_shape, arguments=shape)(x)
评论列表
文章目录