def std(x, axis=None, keepdims=False):
    """Return the standard deviation of ``x``, optionally along one axis.

    Args:
        x: Input tensor.
        axis: Integer axis to reduce over (negative values count from the
            end), or ``None`` to reduce over every element.
        keepdims: If True, reduced axes are kept with length 1; otherwise
            they are removed from the result.

    Returns:
        The result of applying the op built by ``get_op`` (with
        ``_std`` as the computation and ``_compute_output_shape`` as the
        shape function) to ``x``.
    """
    def _std(x, axis, keepdims):
        # NOTE(review): torch.std defaults to the unbiased (N-1) estimator;
        # NumPy/TensorFlow-style backends use the population (N) form.
        # Confirm which one callers of this backend expect.
        if axis is None:
            # Reducing over all elements yields a 0-d tensor; restore every
            # axis as length 1 when keepdims is requested so the result
            # matches _compute_output_shape.
            y = torch.std(x)
            return y.reshape([1] * x.dim()) if keepdims else y
        # keepdim is honoured by torch.std itself, and with keepdim=False the
        # reduced axis is already removed, so no extra squeeze is needed
        # (squeezing again would drop a neighbouring size-1 axis).
        return torch.std(x, dim=axis, keepdim=keepdims)

    def _compute_output_shape(x, axis, keepdims):
        if axis is None:
            if not keepdims:
                return ()
            # Every axis is reduced, and each one is kept with length 1.
            return (1,) * len(_get_shape(x))
        shape = list(_get_shape(x))
        if keepdims:
            shape[axis] = 1
        else:
            del shape[axis]
        return tuple(shape)

    return get_op(_std, output_shape=_compute_output_shape, arguments=[axis, keepdims])(x)
# 评论列表 (comment list)
# 文章目录 (table of contents)