def expand_dims(x, axis=-1):
    """Insert a size-1 dimension into ``x`` at position ``axis``.

    The tensor computation is ``torch.unsqueeze``; the matching output shape
    is derived from ``_get_shape(x)`` with a ``1`` inserted at the same
    position. Both callables are handed to ``get_op`` together with ``axis``,
    and the op it returns is applied to ``x``.

    Args:
        x: Input passed to the op built by ``get_op``.
        axis: Position of the new dimension. Negative values count from the
            end of the *output* shape, following ``torch.unsqueeze``
            (``-1`` appends a trailing dimension).

    Returns:
        The result of calling the op built by ``get_op`` on ``x``.
    """

    def _expand_dims(x, axis=axis):
        return torch.unsqueeze(x, axis)

    def _compute_output_shape(x, axis=axis):
        shape = list(_get_shape(x))
        # list.insert(-1, ...) places the new entry *before* the last
        # element, while torch.unsqueeze(x, -1) appends it. Normalize
        # negative axes the way torch does (axis + ndim + 1) so the reported
        # shape agrees with the tensor actually produced.
        index = axis + len(shape) + 1 if axis < 0 else axis
        shape.insert(index, 1)
        return shape

    return get_op(
        _expand_dims,
        output_shape=_compute_output_shape,
        arguments=[axis],
    )(x)
评论列表
文章目录