torch_utils.py source code

Project: inferno    Author: inferno-pytorch

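# In PyTorch >= 0.4, Variable is merged into torch.Tensor and kept for backward compatibility.
from torch.autograd import Variable


def where(condition, if_true, if_false):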
    """
    Torch equivalent of numpy.where.

    Parameters
    ----------
    condition : torch.ByteTensor or torch.cuda.ByteTensor or torch.autograd.Variable
        Element-wise mask to check. It must contain only 0s and 1s, because it
        is used arithmetically rather than as a boolean index.
    if_true : torch.Tensor or torch.cuda.Tensor or torch.autograd.Variable
        Output values where condition is 1.
    if_false : torch.Tensor or torch.cuda.Tensor or torch.autograd.Variable
        Output values where condition is 0.

    Returns
    -------
    torch.Tensor or torch.autograd.Variable
        Elements of if_true where condition is 1 and of if_false where it is 0.

    Raises
    ------
    AssertionError
        If if_true and if_false are not both Variables or both tensors.
    AssertionError
        If condition is not a Variable while if_true or if_false is one, or vice versa.
    AssertionError
        If if_true and if_false don't have the same data type.
    """
    if isinstance(if_true, Variable) or isinstance(if_false, Variable):
        assert isinstance(condition, Variable), \
            "Condition must be a variable if either if_true or if_false is a variable."
        assert isinstance(if_true, Variable) and isinstance(if_false, Variable), \
            "Both if_true and if_false must be variables if either is one."
        assert if_true.data.type() == if_false.data.type(), \
            "Type mismatch: {} and {}".format(if_true.data.type(), if_false.data.type())
    else:
        assert not isinstance(condition, Variable), \
            "Condition must not be a variable because neither if_true nor if_false is one."
        # noinspection PyArgumentList
        assert if_true.type() == if_false.type(), \
            "Type mismatch: {} and {}".format(if_true.type(), if_false.type())
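    # Cast the 0/1 mask to the data type of if_true so it can be used in arithmetic.
    casted_condition = condition.type_as(if_true)
    # Blend: 1 * if_true + 0 * if_false where the mask is 1, and the reverse where it is 0.
    # Caveat: inf or nan in the unselected input still yields nan, since 0 * inf = nan.
    output = casted_condition * if_true + (1 - casted_condition) * if_false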
    return output
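
The arithmetic blend in the last lines of `where` can be checked without PyTorch. Below is a minimal pure-Python sketch, assuming plain lists of floats stand in for tensors; `blend_where` and `select_where` are hypothetical helper names, not part of inferno. It shows that the blend agrees with a branching select for finite inputs, but returns NaN when the unselected input is infinite, because `0.0 * inf` is NaN under IEEE 754 arithmetic.

```python
import math


# Hypothetical helpers for illustration only; lists of floats stand in for tensors.
def blend_where(condition, if_true, if_false):
    """Select element-wise with the arithmetic blend c * a + (1 - c) * b."""
    out = []
    for c, a, b in zip(condition, if_true, if_false):
        c = float(c)  # like condition.type_as(if_true): True -> 1.0, False -> 0.0
        out.append(c * a + (1.0 - c) * b)
    return out


def select_where(condition, if_true, if_false):
    """Select element-wise by branching, as numpy.where does."""
    return [a if c else b for c, a, b in zip(condition, if_true, if_false)]


cond = [True, False, True]
print(blend_where(cond, [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]))   # [1.0, 20.0, 3.0]
print(select_where(cond, [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]))  # [1.0, 20.0, 3.0]

# An infinite value in the unselected input poisons the blend: 0.0 * inf is nan.
print(blend_where([True], [1.0], [math.inf]))   # [nan]
print(select_where([True], [1.0], [math.inf]))  # [1.0]
```

Newer PyTorch releases also ship a built-in `torch.where(condition, x, y)` that picks values directly instead of multiplying by the mask, which avoids this NaN case in the output.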