Python code examples of is_tensor()
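
torch.is_tensor(obj) returns True only for torch tensor objects, which is why the snippets below use it to dispatch between tensors, storages, numpy arrays, and plain Python containers. A minimal illustration, assuming a recent PyTorch (variable values here are purely for demonstration):

import numpy as np
import torch

print(torch.is_tensor(torch.zeros(3)))   # True: a torch tensor
print(torch.is_tensor(np.zeros(3)))      # False: a numpy array
print(torch.is_tensor([1.0, 2.0]))       # False: a plain Python list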

__init__.py (project: pytorch-dist, author: apaszke)
def _wrap_function(function, ffi):
    @wraps(function)
    def safe_call(*args, **kwargs):
        args = tuple(ffi.cast(_torch_to_cffi.get(type(arg), 'void') + '*', arg._cdata)
                     if torch.is_tensor(arg) or torch.is_storage(arg)
                     else arg
                     for arg in args)
        args = (function,) + args
        result = torch._C._safe_call(*args, **kwargs)
        if isinstance(result, ffi.CData):
            typeof = ffi.typeof(result)
            if typeof.kind == 'pointer':
                cdata = int(ffi.cast('uintptr_t', result))
                cname = typeof.item.cname
                if cname in _cffi_to_torch:
                    return _cffi_to_torch[cname](cdata=cdata)
        return result
    return safe_call
dataloader.py (project: pytorch-dist, author: apaszke)
def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"
    if torch.is_tensor(batch[0]):
        return torch.cat([t.view(1, *t.size()) for t in batch], 0)
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], str):
        return batch
    elif isinstance(batch[0], collections.Iterable):
        # if each batch element is not a tensor, then it should be a tuple
        # of tensors; in that case we collate each element in the tuple
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, or lists; found {}"
                     .format(type(batch[0]))))
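
For context, a minimal self-contained sketch of the same collation idea with current APIs (torch.stack rather than cat/view; collate_simple and the sample data are illustrative, not part of the project above):

import torch

def collate_simple(batch):
    # Stack per-sample tensors along a new batch dimension and gather the labels.
    xs, ys = zip(*batch)
    return torch.stack(xs, 0), torch.tensor(ys)

samples = [(torch.randn(3), 0), (torch.randn(3), 1)]
x, y = collate_simple(samples)
print(x.shape, y)   # torch.Size([2, 3]) tensor([0, 1])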
utils.py (project: pytorch-dist, author: apaszke)
def clear(self, *args):
    if len(args) == 1 and isinstance(args[0], list):
        args = args[0]
    def _clear(f):
        if not hasattr(self, f):
            return
        attr = getattr(self, f)
        if torch.is_tensor(attr):
            attr.set_()
        elif isinstance(attr, list):
            del attr[:]
        else:
            setattr(self, f, None)
    for key in args:
        _clear(key)
    return self
common.py (project: pytorch-dist, author: apaszke)
def to_gpu(obj, type_map={}):
    if torch.is_tensor(obj):
        t = type_map.get(type(obj), get_gpu_type(type(obj)))
        return obj.clone().type(t)
    elif torch.is_storage(obj):
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, Variable):
        assert obj.creator is None
        t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
        return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
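
The same tensor/container dispatch is easy to reproduce with the modern device API; a stand-alone sketch (move_to is an illustrative helper, not the test harness's to_gpu):

import torch
from copy import deepcopy

def move_to(obj, device):
    # Recursively move tensors inside lists, tuples, and dicts to a device;
    # anything else is deep-copied unchanged.
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to(o, device) for o in obj)
    return deepcopy(obj)

batch = {'x': torch.randn(2, 3), 'meta': [torch.arange(4), 'label']}
batch = move_to(batch, 'cuda' if torch.cuda.is_available() else 'cpu')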
aucmeter.py (project: tnt, author: pytorch)
def add(self, output, target):
        if torch.is_tensor(output):
            output = output.cpu().squeeze().numpy()
        if torch.is_tensor(target):
            target = target.cpu().squeeze().numpy()
        elif isinstance(target, numbers.Number):
            target = np.asarray([target])
        assert np.ndim(output) == 1, \
            'wrong output size (1D expected)'
        assert np.ndim(target) == 1, \
            'wrong target size (1D expected)'
        assert output.shape[0] == target.shape[0], \
            'number of outputs and targets does not match'
        assert np.all(np.add(np.equal(target, 1), np.equal(target, 0))), \
            'targets should be binary (0, 1)'

        self.scores = np.append(self.scores, output)
        self.targets = np.append(self.targets, target)
dictionary.py (project: ParlAI, author: facebookresearch)
def string(self, tensor, bpe_symbol=None, escape_unk=False):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return '\n'.join(self.string(t) for t in tensor)

        def token_string(i):
            if i == self.unk():
                return self.unk_string(escape_unk)
            else:
                return self[i]

        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
        if bpe_symbol is not None:
            sent = sent.replace(bpe_symbol, '')
        return sent
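
The same id-to-string conversion in miniature, with a plain list standing in for the Dictionary class (vocab, eos_id, and indices_to_string are illustrative names only):

import torch

vocab = ['<pad>', '<eos>', '<unk>', 'hello', 'world']
eos_id = 1

def indices_to_string(t):
    # Join the token for every id except the end-of-sentence marker.
    return ' '.join(vocab[i] for i in t.tolist() if i != eos_id)

print(indices_to_string(torch.tensor([3, 4, 1])))   # "hello world"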
common.py (project: pyro, author: uber)
def to_gpu(obj, type_map={}):
    if torch.is_tensor(obj):
        t = type_map.get(type(obj), get_gpu_type(type(obj)))
        return obj.clone().type(t)
    elif torch.is_storage(obj):
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, Variable):
        assert obj.is_leaf
        t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
        return Variable(obj.data.clone().type(
            t), requires_grad=obj.requires_grad)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
torch_utils.py (project: inferno, author: inferno-pytorch)
def unwrap(tensor_or_variable, to_cpu=True, as_numpy=False):
    if isinstance(tensor_or_variable, (list, tuple)):
        return type(tensor_or_variable)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy)
                                         for _t in tensor_or_variable])
    elif isinstance(tensor_or_variable, Variable):
        tensor = tensor_or_variable.data
    elif torch.is_tensor(tensor_or_variable):
        tensor = tensor_or_variable
    elif isinstance(tensor_or_variable, np.ndarray):
        return tensor_or_variable
    elif isinstance(tensor_or_variable, (float, int)):
        return tensor_or_variable
    else:
        raise NotUnwrappableError("Cannot unwrap a '{}'."
                                  .format(type(tensor_or_variable).__name__))
    # Transfer to CPU if required
    if to_cpu:
        with delayed_keyboard_interrupt():
            tensor = tensor.cpu()
    # Convert to numpy if required
    if as_numpy:
        return tensor.cpu().numpy()
    else:
        return tensor
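
A simplified, self-contained version of the same unwrapping idea for current PyTorch, where Variables and tensors are one type (to_numpy is an illustrative helper):

import numpy as np
import torch

def to_numpy(x):
    # Tensors become numpy arrays; containers are recursed; everything else passes through.
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    if isinstance(x, (list, tuple)):
        return type(x)(to_numpy(v) for v in x)
    return x

print(to_numpy((torch.ones(2), [torch.zeros(1), 3.5])))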
tensorport.py (project: jack, author: uclmr)
def create_torch_variable(self, value, gpu=False):
        """Convenience method that produces a tensor given the value of the defined type.

        Returns: a torch tensor of the same type.
        """
        if isinstance(value, torch.autograd.Variable):
            if gpu:
                value = value.cuda()
            return value
        if not torch.is_tensor(value):
            if not isinstance(value, np.ndarray):
                value = np.array(value, dtype=self.dtype.as_numpy_dtype)
            else:
                value = value.astype(self.dtype.as_numpy_dtype)
            if value.size == 0:
                return value
            allowed = [tf.int16, tf.int32, tf.int64, tf.float16, tf.float32, tf.float64, tf.int8]
            if self.dtype in allowed:
                value = torch.autograd.Variable(torch.from_numpy(value))
        else:
            value = torch.autograd.Variable(value)
        if gpu and isinstance(value, torch.autograd.Variable):
            value = value.cuda()
        return value
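
A modern counterpart without the Variable wrapper or the TensorFlow dtypes, shown only to illustrate the is_tensor guard (to_device_tensor is a hypothetical name):

import numpy as np
import torch

def to_device_tensor(value, device='cpu'):
    # Reuse tensors as-is; otherwise convert via numpy, then move to the target device.
    if torch.is_tensor(value):
        return value.to(device)
    return torch.as_tensor(np.asarray(value)).to(device)

print(to_device_tensor([[1, 2], [3, 4]]))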
dictionary.py (project: fairseq-py, author: facebookresearch)
def string(self, tensor, bpe_symbol=None, escape_unk=False):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return '\n'.join(self.string(t) for t in tensor)

        def token_string(i):
            if i == self.unk():
                return self.unk_string(escape_unk)
            else:
                return self[i]

        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
        if bpe_symbol is not None:
            sent = sent.replace(bpe_symbol, '')
        return sent
train_utils.py (project: kaggle-carvana, author: ematvey)
def gpu_preloader_iter(dataloader):
    loader_iter = iter(dataloader)
    bx, by = None, None
    while 1:
        try:
            x, y = bx, by
            bx, by = next(loader_iter)
            if torch.is_tensor(bx):
                bx = bx.cuda(async=True)
            if torch.is_tensor(by):
                by = by.cuda(async=True)
            if x is None or y is None:
                x, y = next(loader_iter)
                if torch.is_tensor(x):
                    x = x.cuda()
                if torch.is_tensor(y):
                    y = y.cuda()
            yield x, y
        except StopIteration:
            if bx is not None:
                yield bx, by
            return
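
One caveat about the snippet above: async became a reserved word in Python 3.7, so on current interpreters the same transfer is spelled with non_blocking instead; a minimal illustration:

import torch

if torch.cuda.is_available():
    x = torch.randn(4, 4).pin_memory()
    # non_blocking=True lets the host-to-device copy overlap with computation.
    x = x.cuda(non_blocking=True)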
__init__.py (project: pytorch, author: tylergenter)
def _wrap_function(function, ffi):
    @wraps(function)
    def safe_call(*args, **kwargs):
        args = tuple(ffi.cast(_torch_to_cffi.get(type(arg), 'void') + '*', arg._cdata)
                     if torch.is_tensor(arg) or torch.is_storage(arg)
                     else arg
                     for arg in args)
        args = (function,) + args
        result = torch._C._safe_call(*args, **kwargs)
        if isinstance(result, ffi.CData):
            typeof = ffi.typeof(result)
            if typeof.kind == 'pointer':
                cdata = int(ffi.cast('uintptr_t', result))
                cname = typeof.item.cname
                if cname in _cffi_to_torch:
                    return _cffi_to_torch[cname](cdata=cdata)
        return result
    return safe_call
read_lua_file.py (project: pytorch, author: tylergenter)
def _load_backend(obj):
    if hasattr(obj, '_type'):
        obj._backend = type2backend[obj._type]
        return
    # Try to find tensor attributes and infer type from them
    for key in dir(obj):
        attr = getattr(obj, key)
        if torch.is_tensor(attr):
            try:
                obj._backend = type2backend[type(attr)]
            except KeyError:
                pass
    # Monkey patch the forward to capture the type of input
    updateOutput_orig = obj.updateOutput

    def updateOutput_patch(*args):
        input = args[0]
        while not torch.is_tensor(input):
            input = input[0]
        obj._backend = type2backend[type(input)]
        obj.updateOutput = updateOutput_orig
        return obj.updateOutput(*args)
    obj.updateOutput = updateOutput_patch
scatter_gather.py (project: pytorch, author: tylergenter)
def scatter(inputs, target_gpus, dim=0):
    """
    Slices variables into approximately equal chunks and
    distributes them across the given GPUs. Duplicates
    references to objects that are not variables. Does not
    support Tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, Variable):
            return Scatter(target_gpus, dim=dim)(obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        if isinstance(obj, tuple):
            return tuple(zip(*map(scatter_map, obj)))
        if isinstance(obj, list):
            return tuple(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict):
            return tuple(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return tuple(obj for targets in target_gpus)

    return scatter_map(inputs)
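
Current PyTorch exposes the same behaviour through torch.nn.parallel.scatter; a small usage sketch, guarded because it needs at least two GPUs:

import torch
from torch.nn.parallel import scatter

if torch.cuda.device_count() >= 2:
    x = torch.randn(8, 4)
    chunks = scatter(x, [0, 1])            # two [4, 4] chunks, one per GPU
    print([c.device for c in chunks])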
utils.py (project: pytorch, author: tylergenter)
def clear(self, *args):
    if len(args) == 1 and isinstance(args[0], list):
        args = args[0]

    def _clear(f):
        if not hasattr(self, f):
            return
        attr = getattr(self, f)
        if torch.is_tensor(attr):
            attr.set_()
        elif isinstance(attr, list):
            del attr[:]
        else:
            setattr(self, f, None)
    for key in args:
        _clear(key)
    return self
test_autograd.py (project: pytorch, author: tylergenter)
def create_input(call_args, requires_grad=True):
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
            return arg
        elif isinstance(arg, tuple) and not isinstance(arg[0], Variable):
            return Variable(torch.randn(*arg).double(), requires_grad=requires_grad)
        elif torch.is_tensor(arg):
            if isinstance(arg, torch.FloatTensor):
                return Variable(arg.double(), requires_grad=requires_grad)
            else:
                return Variable(arg, requires_grad=requires_grad)
        else:
            return arg
    return tuple(map_arg(arg) for arg in call_args)
common.py (project: pytorch, author: tylergenter)
def to_gpu(obj, type_map={}):
    if torch.is_tensor(obj):
        t = type_map.get(type(obj), get_gpu_type(type(obj)))
        return obj.clone().type(t)
    elif torch.is_storage(obj):
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, Variable):
        assert obj.is_leaf
        t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
        return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
common.py (project: pytorch, author: tylergenter)
def safeCoalesce(self, t):
        tc = t.coalesce()

        value_map = {}
        for idx, val in zip(t._indices().t(), t._values()):
            idx_tup = tuple(idx)
            if idx_tup in value_map:
                value_map[idx_tup] += val
            else:
                value_map[idx_tup] = val.clone() if torch.is_tensor(val) else val

        new_indices = sorted(list(value_map.keys()))
        new_values = [value_map[idx] for idx in new_indices]
        if t._values().ndimension() < 2:
            new_values = t._values().new(new_values)
        else:
            new_values = torch.stack(new_values)

        new_indices = t._indices().new(new_indices).t()
        tg = t.new(new_indices, new_values, t.size())

        self.assertEqual(tc._indices(), tg._indices())
        self.assertEqual(tc._values(), tg._values())

        return tg
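
What the test checks, in short: coalescing a sparse COO tensor sums the values at duplicate indices. A quick stand-alone demonstration with the current constructor:

import torch

i = torch.tensor([[0, 0, 1],
                  [1, 1, 0]])
v = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(i, v, (2, 2)).coalesce()
print(s.indices())   # tensor([[0, 1], [1, 0]])
print(s.values())    # tensor([3., 3.])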
misc.py (project: ml-utils, author: LinxiFan)
def shapes_all(data):
    """
    Recursively walks the data (can be tuples, lists, or dict) and
    replaces a tensor with its shape tuple whenever it meets a tensor
    """
    if isinstance(data, (tuple, list)):
        ans = map(shapes_all, data)
        return type(data)(ans)
    elif isinstance(data, dict):
        return {k: shapes_all(v) for k, v in data.items()}
    elif (isinstance(data, np.ndarray)
          or torch.is_tensor(data)
          or isinstance(data, torch.autograd.Variable)
          or isinstance(data, torch.nn.Parameter)):
        return shape(data)
    else:
        return data
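
A self-contained variant without the project's shape() helper, to show the recursive replacement (shapes_all_simple is an illustrative name):

import numpy as np
import torch

def shapes_all_simple(data):
    # Replace every tensor or array in a nested structure with its shape tuple.
    if isinstance(data, dict):
        return {k: shapes_all_simple(v) for k, v in data.items()}
    if isinstance(data, (tuple, list)):
        return type(data)(shapes_all_simple(v) for v in data)
    if torch.is_tensor(data) or isinstance(data, np.ndarray):
        return tuple(data.shape)
    return data

print(shapes_all_simple({'x': torch.zeros(2, 3), 'y': [np.zeros(4), 7]}))
# {'x': (2, 3), 'y': [(4,), 7]}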
misc.py (project: ml-utils, author: LinxiFan)
def to_float_tensor(x, copy=True):
    """
    FloatTensor is the most used torch type, so we create a special method for it
    """
    if torch.is_tensor(x):
        assert isinstance(x, torch.FloatTensor)
        return x
    elif TC.is_variable(x):
        x = TC.to_tensor(x)
        assert isinstance(x, torch.FloatTensor)
        return x
    elif not TC.is_numpy(x):
        x = np.array(x, copy=False)
    x = np_cast(x, np.float32)
    if copy:
        return torch.FloatTensor(x)
    else:
        return torch.from_numpy(x)
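
The same coercion with today's torch.as_tensor, only as a stand-alone sketch (as_float_tensor is a hypothetical helper):

import numpy as np
import torch

def as_float_tensor(x):
    # Tensors are cast to float32; arrays and sequences go through numpy first.
    if torch.is_tensor(x):
        return x.float()
    return torch.as_tensor(np.asarray(x), dtype=torch.float32)

print(as_float_tensor([1, 2, 3]).dtype)      # torch.float32
print(as_float_tensor(np.zeros(2)).dtype)    # torch.float32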
voc0712.py (project: ssd_pytorch, author: miraclebiu)
def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on 0 dim
    """
    targets = []
    imgs = []
    for _, sample in enumerate(batch):
        for _, tup in enumerate(sample):
            #pdb.set_trace()
            if torch.is_tensor(tup):
                imgs.append(tup)
            elif isinstance(tup, type([])):
                annos = [torch.Tensor(a) for a in tup]
                #pdb.set_trace()
                targets.append(torch.stack(annos, 0))

    return (torch.stack(imgs, 0), targets)
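
An illustrative call with the definition above in scope: two images whose annotation lists have different lengths still collate cleanly (the box coordinates below are made up for demonstration).

import torch

batch = [
    (torch.zeros(3, 300, 300), [[0.1, 0.1, 0.5, 0.5, 1.0]]),
    (torch.zeros(3, 300, 300), [[0.2, 0.2, 0.6, 0.6, 2.0],
                                [0.0, 0.0, 0.3, 0.3, 3.0]]),
]
imgs, targets = detection_collate(batch)
print(imgs.shape)                   # torch.Size([2, 3, 300, 300])
print([t.shape for t in targets])   # [torch.Size([1, 5]), torch.Size([2, 5])]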
__init__.py (project: pytorch-coriander, author: hughperkins)
def _wrap_function(function, ffi):
    @wraps(function)
    def safe_call(*args, **kwargs):
        args = tuple(ffi.cast(_torch_to_cffi.get(type(arg), 'void') + '*', arg._cdata)
                     if torch.is_tensor(arg) or torch.is_storage(arg)
                     else arg
                     for arg in args)
        args = (function,) + args
        result = torch._C._safe_call(*args, **kwargs)
        if isinstance(result, ffi.CData):
            typeof = ffi.typeof(result)
            if typeof.kind == 'pointer':
                cdata = int(ffi.cast('uintptr_t', result))
                cname = typeof.item.cname
                if cname in _cffi_to_torch:
                    return _cffi_to_torch[cname](cdata=cdata)
        return result
    return safe_call
read_lua_file.py (project: pytorch-coriander, author: hughperkins)
def _load_backend(obj):
    if hasattr(obj, '_type'):
        obj._backend = type2backend[obj._type]
        return
    # Try to find tensor attributes and infer type from them
    for key in dir(obj):
        attr = getattr(obj, key)
        if torch.is_tensor(attr):
            try:
                obj._backend = type2backend[type(attr)]
            except KeyError:
                pass
    # Monkey patch the forward to capture the type of input
    updateOutput_orig = obj.updateOutput

    def updateOutput_patch(*args):
        input = args[0]
        while not torch.is_tensor(input):
            input = input[0]
        obj._backend = type2backend[type(input)]
        obj.updateOutput = updateOutput_orig
        return obj.updateOutput(*args)
    obj.updateOutput = updateOutput_patch
scatter_gather.py (project: pytorch-coriander, author: hughperkins)
def scatter(inputs, target_gpus, dim=0):
    """
    Slices variables into approximately equal chunks and
    distributes them across the given GPUs. Duplicates
    references to objects that are not variables. Does not
    support Tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, Variable):
            return Scatter(target_gpus, dim=dim)(obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        if isinstance(obj, tuple):
            return tuple(zip(*map(scatter_map, obj)))
        if isinstance(obj, list):
            return tuple(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict):
            return tuple(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return tuple(obj for targets in target_gpus)

    return scatter_map(inputs)
utils.py (project: pytorch-coriander, author: hughperkins)
def clear(self, *args):
    if len(args) == 1 and isinstance(args[0], list):
        args = args[0]

    def _clear(f):
        if not hasattr(self, f):
            return
        attr = getattr(self, f)
        if torch.is_tensor(attr):
            attr.set_()
        elif isinstance(attr, list):
            del attr[:]
        else:
            setattr(self, f, None)
    for key in args:
        _clear(key)
    return self
common.py (project: pytorch-coriander, author: hughperkins)
def to_gpu(obj, type_map={}):
    if torch.is_tensor(obj):
        t = type_map.get(type(obj), get_gpu_type(type(obj)))
        return obj.clone().type(t)
    elif torch.is_storage(obj):
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, Variable):
        assert obj.is_leaf
        t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
        return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
common.py (project: pytorch-coriander, author: hughperkins)
def safeCoalesce(self, t):
        tc = t.coalesce()

        value_map = {}
        for idx, val in zip(t._indices().t(), t._values()):
            idx_tup = tuple(idx)
            if idx_tup in value_map:
                value_map[idx_tup] += val
            else:
                value_map[idx_tup] = val.clone() if torch.is_tensor(val) else val

        new_indices = sorted(list(value_map.keys()))
        new_values = [value_map[idx] for idx in new_indices]
        if t._values().ndimension() < 2:
            new_values = t._values().new(new_values)
        else:
            new_values = torch.stack(new_values)

        new_indices = t._indices().new(new_indices).t()
        tg = t.new(new_indices, new_values, t.size())

        self.assertEqual(tc._indices(), tg._indices())
        self.assertEqual(tc._values(), tg._values())

        return tg
state.py (project: seq2seq.pytorch, author: eladhoffer)
def __merge_states(self, state_list, type_state='hidden'):
        if state_list is None:
            return None
        if isinstance(state_list[0], State):
            return State().from_list(state_list)
        if isinstance(state_list[0], tuple):
            return tuple([self.__merge_states(s, type_state) for s in zip(*state_list)])
        else:
            if isinstance(state_list[0], Variable) or torch.is_tensor(state_list[0]):
                if type_state == 'hidden':
                    batch_dim = 0 if state_list[0].dim() < 3 else 1
                else:
                    batch_dim = 0 if self.batch_first else 1
                return torch.cat(state_list, batch_dim)
            else:
                assert state_list[1:] == state_list[:-1]  # all items are equal
                return state_list[0]
multi_language.py (project: seq2seq.pytorch, author: eladhoffer)
def create_padded_batch(max_length=100, max_tokens=None,
                        batch_first=False, sort=False,
                        pack=False, augment=False):
    def collate(seqs, sort=sort, pack=pack):
        if not torch.is_tensor(seqs[0]):
            if sort or pack:  # packing requires a sorted batch by length
                # sort by the first set
                seqs.sort(key=lambda x: len(x[0]), reverse=True)
            # TODO: for now, just the first input will be packed
            return tuple([collate(s, sort=False, pack=pack and (i == 0))
                          for i, s in enumerate(zip(*seqs))])
        return batch_sequences(seqs, max_length=max_length,
                               max_tokens=max_tokens,
                               batch_first=batch_first,
                               sort=False, pack=pack, augment=augment)
    return collate
optimizer.py (project: pytorch, author: ezyang)
def __init__(self, params, defaults):
        self.defaults = defaults

        if isinstance(params, Variable) or torch.is_tensor(params):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Variables or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
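
The type check at the top of __init__ is what users hit when they pass a single tensor or parameter instead of an iterable; a brief, purely illustrative demonstration:

import torch

w = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.SGD([w], lr=0.1)    # fine: an iterable of parameters
try:
    torch.optim.SGD(w, lr=0.1)        # a bare parameter triggers the TypeError above
except TypeError as err:
    print(err)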

