python类range()的实例源码

PartialLinear.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def __init__(self, inputsize, outputsize, bias=True):
        """Build the layer as a small lookup-then-multiply network.

        Args:
            inputsize: stored as ``self.inputsize`` and used as the second
                size argument of the internal ``LookupTable``.
            outputsize: stored as ``self.outputsize``; also the number of
                entries in ``self.allcolumns``.
            bias: when truthy, allocate zero-filled ``bias`` / ``gradBias``
                of shape ``(1, outputsize)``; otherwise both are ``None``.
        """
        super(PartialLinear, self).__init__()

        # Two parallel branches (pass-through + lookup) feeding an MM module.
        branches = ParallelTable()
        branches.add(Identity())
        branches.add(LookupTable(outputsize, inputsize))
        net = Sequential()
        net.add(branches)
        net.add(MM(False, True))
        self.network = net

        self.bias = torch.zeros(1, outputsize) if bias else None
        self.gradBias = torch.zeros(1, outputsize) if bias else None

        # Partition bookkeeping: start from the full set of column indices.
        self.inputsize = inputsize
        self.outputsize = outputsize
        last_column = outputsize - 1
        self.allcolumns = torch.range(0, last_column).long()
        self.resetPartition()
        # Scratch attributes start out unset.
        self.addBuffer = None
        self.buffer = None
test_nn.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def test_load_parameter_dict(self):
        """load_parameter_dict must install the very objects it is given.

        The container holds one ``Linear`` under two names, a nested
        container with a bias-free conv, and a ``None`` entry.
        """
        shared_linear = nn.Linear(5, 5)
        conv_block = nn.Container(conv=nn.Conv2d(3, 3, 3, bias=False))
        net = nn.Container(
            linear1=shared_linear,
            linear2=shared_linear,
            block=conv_block,
            empty=None,
        )
        new_weight = Variable(torch.ones(5, 5))
        new_bias = Variable(torch.range(1, 3))
        param_dict = {
            'linear1.weight': new_weight,
            'block.conv.bias': new_bias,
        }
        net.load_parameter_dict(param_dict)
        # Identity (not equality) checks: the parameters are replaced, not copied.
        self.assertIs(net.linear1.weight, new_weight)
        self.assertIs(net.block.conv.bias, new_bias)
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def assertIsOrdered(self, order, x, mxx, ixx, task):
        """Assert that a ``torch.sort`` result is ordered and its indices are valid.

        Args:
            order: ``'ascending'`` or ``'descending'``.
            x: the input the sort result is checked against.
            mxx: the sorted values.
            ixx: the indices such that ``x[k][ixx[k][j]] == mxx[k][j]``.
            task: label inserted into the failure messages.

        Raises:
            ValueError: if ``order`` is not one of the two supported values.
        """
        # NOTE(review): only the leading SIZE x SIZE corner of ``mxx`` is
        # order-checked; presumably callers pass 4x4 inputs -- confirm.
        SIZE = 4
        if order == 'descending':
            check_order = lambda a, b: a >= b
        elif order == 'ascending':
            check_order = lambda a, b: a <= b
        else:
            # The original called an undefined ``error(...)`` here, which
            # surfaced as a NameError instead of the intended message.
            raise ValueError(
                'unknown order "{}", must be "ascending" or "descending"'.format(order))

        # Every adjacent pair within a row must respect the requested order.
        for j, k in product(range(SIZE), range(1, SIZE)):
            self.assertTrue(check_order(mxx[j][k-1], mxx[j][k]),
                    'torch.sort ({}) values unordered for {}'.format(order, task))

        # Every index must point back at the matching value in ``x``.
        seen = set()
        size = x.size(x.dim()-1)
        for k in range(size):
            seen.clear()
            for j in range(size):
                self.assertEqual(x[k][ixx[k][j]], mxx[k][j],
                        'torch.sort ({}) indices wrong for {}'.format(order, task))
                seen.add(ixx[k][j])
            self.assertEqual(len(seen), size)
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_cat(self):
        """Concatenate two, then three, transposed tensors along each dim.

        Every input must be recoverable from the result with ``narrow``.
        """
        side = 10

        def check_pieces(lengths, axis):
            # Build the inputs in order, join them, then walk the result
            # with a running offset to compare slice by slice.
            pieces = [torch.rand(n, side, side).transpose(0, axis) for n in lengths]
            joined = torch.cat(tuple(pieces), axis)
            start = 0
            for n, piece in zip(lengths, pieces):
                self.assertEqual(joined.narrow(axis, start, n), piece, 0)
                start += n

        # Two inputs.
        for axis in range(3):
            check_pieces((13, 17), axis)

        # Three inputs, plus the empty-sequence error case.
        for axis in range(3):
            check_pieces((13, 17, 19), axis)
            self.assertRaises(ValueError, lambda: torch.cat([]))
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_index_copy(self):
        """index_copy_ along dim 0 must match copying the selected slices one by one."""
        n_src, n_dst = 3, 20

        # 3-D case: the reference copy is updated slice by slice with copy_.
        target = torch.randn(n_dst, 4, 5)
        source = torch.randn(n_src, 4, 5)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src).long()
        reference = target.clone()
        target.index_copy_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]].copy_(source[k])
        self.assertEqual(target, reference, 0)

        # 1-D case: the reference copy is updated by element assignment.
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src).long()
        reference = target.clone()
        target.index_copy_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]] = source[k]
        self.assertEqual(target, reference, 0)
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_index_add(self):
        """index_add_ along dim 0 must match accumulating the selected slices one by one."""
        n_src, n_dst = 3, 3

        # 3-D case: whole slices are added in place into the reference copy.
        target = torch.randn(n_dst, 4, 5)
        source = torch.randn(n_src, 4, 5)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src).long()
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]].add_(source[k])
        self.assertEqual(target, reference)

        # 1-D case: the reference copy is updated through element assignment.
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src).long()
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            slot = picks[k]
            reference[slot] = reference[slot] + source[k]
        self.assertEqual(target, reference)

    # Fill idx with valid indices.
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def test_scatter(self):
        """scatter_ along a random dim must match a hand-built reference and reject bad indices."""
        m, n, o = (random.randint(10, 20) for _ in range(3))
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)
        full_shape = [m, n, o]

        # The index/source tensors match the target except along ``dim``.
        idx_size = list(full_shape)
        idx_size[dim] = elems_per_row
        idx = torch.LongTensor().resize_(*idx_size)
        self._fill_indices(idx, dim, full_shape[dim], elems_per_row, m, n, o)
        src = torch.Tensor().resize_(*idx_size).normal_()

        actual = torch.zeros(m, n, o).scatter_(dim, idx, src)

        # Reference: write each source element to its coordinate, with the
        # ``dim`` component replaced by the stored index.
        expected = torch.zeros(m, n, o)
        rows, cols, depth = idx_size
        for i in range(rows):
            for j in range(cols):
                for k in range(depth):
                    coord = [i, j, k]
                    coord[dim] = idx[i, j, k]
                    expected[tuple(coord)] = src[i, j, k]
        self.assertEqual(actual, expected, 0)

        # 34 is larger than any extent drawn above (at most 20).
        idx[0][0][0] = 34
        with self.assertRaises(RuntimeError):
            torch.zeros(m, n, o).scatter_(dim, idx, src)
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_masked_copy(self):
        """masked_copy_ must fill the masked slots, in order, from consecutive source elements."""
        n_src, n_dst = 3, 10
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        mask = torch.ByteTensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0))
        reference = target.clone()
        target.masked_copy_(mask, source)

        # Replay by hand: the k-th set bit receives the k-th source element.
        chosen = [pos for pos in range(n_dst) if mask[pos]]
        for k, pos in enumerate(chosen):
            reference[pos] = source[k]
        self.assertEqual(target, reference, 0)

        # A source longer than the number of set bits is accepted.
        source = torch.randn(n_dst)
        target.masked_copy_(mask, source)

        # A source shorter than the number of set bits must be rejected.
        source = torch.randn(n_src - 1)
        with self.assertRaises(RuntimeError):
            target.masked_copy_(mask, source)
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_deepcopy(self):
        """deepcopy must clone the values and keep tensor/storage/view sharing intact."""
        import copy

        first = torch.randn(5, 5)
        second = torch.randn(5, 5)
        flat_view = first.view(25)
        original = [first, [first.storage(), second.storage()], second, flat_view]
        cloned = copy.deepcopy(original)

        # Value equality, element by element of the nested structure.
        self.assertEqual(cloned[0], original[0], 0)
        for slot in range(2):
            self.assertEqual(cloned[1][slot], original[1][slot], 0)
        self.assertEqual(cloned[1], original[1], 0)
        self.assertEqual(cloned[2], original[2], 0)

        # In-place edits of a cloned tensor must show through the cloned
        # storage and the cloned view, while the originals stay untouched.
        cloned[0].add_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][0][pos], original[1][0][pos] + 1)
        self.assertEqual(cloned[3], flat_view + 1)
        cloned[2].sub_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][1][pos], original[1][1][pos] - 1)
test_meters.py 文件源码 项目:tnt 作者: pytorch 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def testClassErrorMeter(self):
        """ClassErrorMeter(topk=[1]) reports 0 after a matching batch and 50.0 after a second batch."""
        mtr = meter.ClassErrorMeter(topk=[1])
        output = torch.eye(3)
        # Older torch builds have no ``arange``; fall back to inclusive ``range``.
        target = torch.arange(0, 3) if hasattr(torch, "arange") else torch.range(0, 2)

        mtr.add(output, target)
        self.assertEqual(mtr.value(), [0], "All should be correct")

        # Relabel in place so that no row of the identity matches its target.
        for pos, label in enumerate((1, 0, 0)):
            target[pos] = label
        mtr.add(output, target)
        self.assertEqual(mtr.value(), [50.0], "Half should be correct")
dataset.py 文件源码 项目:sceneReco 作者: bear63 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def __getitem__(self, index):
        """Return the ``(img, label)`` pair for the 0-based ``index``.

        The record is read from ``self.env`` under the 1-based keys
        ``image-%09d`` and ``label-%09d``. The image is decoded to
        grayscale (``'L'``) and passed through ``self.transform`` when set;
        the label is passed through ``self.target_transform`` when set.
        If the image cannot be decoded, the next sample is returned instead.

        Raises:
            AssertionError: if ``index`` is outside ``0 .. len(self) - 1``.
        """
        # Keys are 1-based (see ``index += 1``), so ``index == len(self)``
        # would ask for a record one past the end; negative values would
        # produce a key that was never written.
        # NOTE(review): presumably keys run 1..len(self) -- confirm against the writer.
        assert 0 <= index < len(self), 'index range error'
        index += 1
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                # ``index`` already holds the 1-based key of the bad record,
                # which equals the 0-based position of the following sample.
                # The original ``self[index + 1]`` skipped that sample.
                return self[index]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % index
            label = str(txn.get(label_key))
            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
test_oim.py 文件源码 项目:open-reid 作者: Cysu 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_forward_backward(self):
        """Check OIMLoss value and input gradient against a softmax reference.

        The lookup table is set to the identity and the targets are
        ``[0, 1, 2]``, so the reference gradient is ``softmax(x) - eye(3)``
        and the reference loss is minus the summed log-probability on the
        diagonal.
        """
        import torch
        import torch.nn.functional as F
        from torch.autograd import Variable
        from reid.loss import OIMLoss
        criterion = OIMLoss(3, 3, scalar=1.0, size_average=False)
        criterion.lut = torch.eye(3)
        x = Variable(torch.randn(3, 3), requires_grad=True)
        y = Variable(torch.range(0, 2).long())
        loss = criterion(x, y)
        loss.backward()
        probs = F.softmax(x)
        grads = probs.data - torch.eye(3)
        abs_diff = torch.abs(grads - x.grad.data)
        # ``assertEquals`` is a deprecated alias that no longer exists in
        # Python 3.12; ``assertEqual`` is the supported spelling.
        self.assertEqual(torch.log(probs).diag().sum(), -loss)
        self.assertTrue(torch.max(abs_diff) < 1e-6)
test_torch.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def test_cat(self):
        """torch.cat over positive and negative dims, split/chunk round-trips, and the empty case."""
        side = 10
        lengths = (13, 17, 19)
        for axis in range(-3, 3):
            norm_axis = axis % 3  # map -3..-1 onto 0..2, leave 0..2 as is
            pieces = tuple(
                torch.rand(n, side, side).transpose(0, norm_axis) for n in lengths)
            joined = torch.cat(pieces, axis)
            # Walk the result with a running offset and compare each slice.
            start = 0
            for n, piece in zip(lengths, pieces):
                self.assertEqual(joined.narrow(norm_axis, start, n), piece, 0)
                start += n

        # Splitting and re-joining along the default dim is the identity.
        base = torch.randn(20, side, side)
        for splitter in (torch.split, torch.chunk):
            self.assertEqual(torch.cat(splitter(base, 7)), base)

        # A list input with the default dim stacks along dim 0.
        extra = torch.randn(1, side, side)
        stacked = torch.cat([base, extra])
        self.assertEqual(stacked.size(), (21, side, side))

        with self.assertRaises(RuntimeError):
            torch.cat([])
test_torch.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def test_index_add(self):
        """index_add_ along dim 0 must match accumulating the selected slices one by one."""
        n_src, n_dst = 3, 3

        # 3-D case: whole slices are added in place into the reference copy.
        target = torch.randn(n_dst, 4, 5)
        source = torch.randn(n_src, 4, 5)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]].add_(source[k])
        self.assertEqual(target, reference)

        # 1-D case: the reference copy is updated through element assignment.
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            slot = picks[k]
            reference[slot] = reference[slot] + source[k]
        self.assertEqual(target, reference)

    # Fill idx with valid indices.
test_torch.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test_deepcopy(self):
        """deepcopy must clone the values and keep tensor/storage/view sharing intact."""
        import copy

        first = torch.randn(5, 5)
        second = torch.randn(5, 5)
        flat_view = first.view(25)
        original = [first, [first.storage(), second.storage()], second, flat_view]
        cloned = copy.deepcopy(original)

        # Value equality, element by element of the nested structure.
        self.assertEqual(cloned[0], original[0], 0)
        for slot in range(2):
            self.assertEqual(cloned[1][slot], original[1][slot], 0)
        self.assertEqual(cloned[1], original[1], 0)
        self.assertEqual(cloned[2], original[2], 0)

        # In-place edits of a cloned tensor must show through the cloned
        # storage and the cloned view, while the originals stay untouched.
        cloned[0].add_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][0][pos], original[1][0][pos] + 1)
        self.assertEqual(cloned[3], flat_view + 1)
        cloned[2].sub_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][1][pos], original[1][1][pos] - 1)
test_torch.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def test_cat(self):
        """torch.cat over positive and negative dims, split/chunk round-trips, and the empty case."""
        side = 10
        lengths = (13, 17, 19)
        for axis in range(-3, 3):
            norm_axis = axis % 3  # map -3..-1 onto 0..2, leave 0..2 as is
            pieces = tuple(
                torch.rand(n, side, side).transpose(0, norm_axis) for n in lengths)
            joined = torch.cat(pieces, axis)
            # Walk the result with a running offset and compare each slice.
            start = 0
            for n, piece in zip(lengths, pieces):
                self.assertEqual(joined.narrow(norm_axis, start, n), piece, 0)
                start += n

        # Splitting and re-joining along the default dim is the identity.
        base = torch.randn(20, side, side)
        for splitter in (torch.split, torch.chunk):
            self.assertEqual(torch.cat(splitter(base, 7)), base)

        # A list input with the default dim stacks along dim 0.
        extra = torch.randn(1, side, side)
        stacked = torch.cat([base, extra])
        self.assertEqual(stacked.size(), (21, side, side))

        with self.assertRaises(RuntimeError):
            torch.cat([])
test_torch.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def test_index_add(self):
        """index_add_ along dim 0 must match accumulating the selected slices one by one."""
        n_src, n_dst = 3, 3

        # 3-D case: whole slices are added in place into the reference copy.
        target = torch.randn(n_dst, 4, 5)
        source = torch.randn(n_src, 4, 5)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]].add_(source[k])
        self.assertEqual(target, reference)

        # 1-D case: the reference copy is updated through element assignment.
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            slot = picks[k]
            reference[slot] = reference[slot] + source[k]
        self.assertEqual(target, reference)

    # Fill idx with valid indices.
test_torch.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_deepcopy(self):
        """deepcopy must clone the values and keep tensor/storage/view sharing intact."""
        import copy

        first = torch.randn(5, 5)
        second = torch.randn(5, 5)
        flat_view = first.view(25)
        original = [first, [first.storage(), second.storage()], second, flat_view]
        cloned = copy.deepcopy(original)

        # Value equality, element by element of the nested structure.
        self.assertEqual(cloned[0], original[0], 0)
        for slot in range(2):
            self.assertEqual(cloned[1][slot], original[1][slot], 0)
        self.assertEqual(cloned[1], original[1], 0)
        self.assertEqual(cloned[2], original[2], 0)

        # In-place edits of a cloned tensor must show through the cloned
        # storage and the cloned view, while the originals stay untouched.
        cloned[0].add_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][0][pos], original[1][0][pos] + 1)
        self.assertEqual(cloned[3], flat_view + 1)
        cloned[2].sub_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][1][pos], original[1][1][pos] - 1)
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test_cat(self):
        """torch.cat over positive and negative dims, split/chunk round-trips, and the empty case."""
        side = 10
        lengths = (13, 17, 19)
        for axis in range(-3, 3):
            norm_axis = axis % 3  # map -3..-1 onto 0..2, leave 0..2 as is
            pieces = tuple(
                torch.rand(n, side, side).transpose(0, norm_axis) for n in lengths)
            joined = torch.cat(pieces, axis)
            # Walk the result with a running offset and compare each slice.
            start = 0
            for n, piece in zip(lengths, pieces):
                self.assertEqual(joined.narrow(norm_axis, start, n), piece, 0)
                start += n

        # Splitting and re-joining along the default dim is the identity.
        base = torch.randn(20, side, side)
        for splitter in (torch.split, torch.chunk):
            self.assertEqual(torch.cat(splitter(base, 7)), base)

        # A list input with the default dim stacks along dim 0.
        extra = torch.randn(1, side, side)
        stacked = torch.cat([base, extra])
        self.assertEqual(stacked.size(), (21, side, side))

        with self.assertRaises(RuntimeError):
            torch.cat([])
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def test_index_copy(self):
        """index_copy_ along dim 0 must match copying the selected slices one by one."""
        n_src, n_dst = 3, 20

        # 3-D case: the reference copy is updated slice by slice with copy_.
        target = torch.randn(n_dst, 4, 5)
        source = torch.randn(n_src, 4, 5)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_copy_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]].copy_(source[k])
        self.assertEqual(target, reference, 0)

        # 1-D case: the reference copy is updated by element assignment.
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_copy_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]] = source[k]
        self.assertEqual(target, reference, 0)
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def test_index_add(self):
        """index_add_ along dim 0 must match accumulating the selected slices one by one."""
        n_src, n_dst = 3, 3

        # 3-D case: whole slices are added in place into the reference copy.
        target = torch.randn(n_dst, 4, 5)
        source = torch.randn(n_src, 4, 5)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]].add_(source[k])
        self.assertEqual(target, reference)

        # 1-D case: the reference copy is updated through element assignment.
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_add_(0, picks, source)
        for k in range(picks.size(0)):
            slot = picks[k]
            reference[slot] = reference[slot] + source[k]
        self.assertEqual(target, reference)

    # Fill idx with valid indices.
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def test_deepcopy(self):
        """deepcopy must clone the values and keep tensor/storage/view sharing intact."""
        import copy

        first = torch.randn(5, 5)
        second = torch.randn(5, 5)
        flat_view = first.view(25)
        original = [first, [first.storage(), second.storage()], second, flat_view]
        cloned = copy.deepcopy(original)

        # Value equality, element by element of the nested structure.
        self.assertEqual(cloned[0], original[0], 0)
        for slot in range(2):
            self.assertEqual(cloned[1][slot], original[1][slot], 0)
        self.assertEqual(cloned[1], original[1], 0)
        self.assertEqual(cloned[2], original[2], 0)

        # In-place edits of a cloned tensor must show through the cloned
        # storage and the cloned view, while the originals stay untouched.
        cloned[0].add_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][0][pos], original[1][0][pos] + 1)
        self.assertEqual(cloned[3], flat_view + 1)
        cloned[2].sub_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][1][pos], original[1][1][pos] - 1)
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def test_ger(self):
        """torch.ger must match an element-by-element outer product.

        Checked for double and float tensors, first with two ordinary
        vectors and then with a first operand expanded from one element.
        """
        types = {
            'torch.DoubleTensor': 1e-8,
            'torch.FloatTensor': 1e-4,
        }

        def check_outer(left, right, tname):
            # Compare against a reference filled one product at a time.
            produced = torch.ger(left, right)
            reference = torch.zeros(100, 100).type(tname)
            for row in range(100):
                for col in range(100):
                    reference[row, col] = left[row] * right[col]
            self.assertEqual(produced, reference)

        for tname in types:
            left = torch.randn(100).type(tname)
            right = torch.randn(100).type(tname)
            check_outer(left, right, tname)

        # Test 0-strided
        for tname in types:
            left = torch.randn(1).type(tname).expand(100)
            right = torch.randn(100).type(tname)
            check_outer(left, right, tname)
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_cat(self):
        """torch.cat over positive and negative dims, split/chunk round-trips, and the empty case."""
        side = 10
        lengths = (13, 17, 19)
        for axis in range(-3, 3):
            norm_axis = axis % 3  # map -3..-1 onto 0..2, leave 0..2 as is
            pieces = tuple(
                torch.rand(n, side, side).transpose(0, norm_axis) for n in lengths)
            joined = torch.cat(pieces, axis)
            # Walk the result with a running offset and compare each slice.
            start = 0
            for n, piece in zip(lengths, pieces):
                self.assertEqual(joined.narrow(norm_axis, start, n), piece, 0)
                start += n

        # Splitting and re-joining along the default dim is the identity.
        base = torch.randn(20, side, side)
        for splitter in (torch.split, torch.chunk):
            self.assertEqual(torch.cat(splitter(base, 7)), base)

        # A list input with the default dim stacks along dim 0.
        extra = torch.randn(1, side, side)
        stacked = torch.cat([base, extra])
        self.assertEqual(stacked.size(), (21, side, side))

        with self.assertRaises(RuntimeError):
            torch.cat([])
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def test_index_copy(self):
        """index_copy_ along dim 0 must match copying the selected slices one by one."""
        n_src, n_dst = 3, 20

        # 3-D case: the reference copy is updated slice by slice with copy_.
        target = torch.randn(n_dst, 4, 5)
        source = torch.randn(n_src, 4, 5)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_copy_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]].copy_(source[k])
        self.assertEqual(target, reference, 0)

        # 1-D case: the reference copy is updated by element assignment.
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        picks = torch.randperm(n_dst).narrow(0, 0, n_src)
        reference = target.clone()
        target.index_copy_(0, picks, source)
        for k in range(picks.size(0)):
            reference[picks[k]] = source[k]
        self.assertEqual(target, reference, 0)
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_masked_scatter(self):
        """masked_scatter_ must fill the masked slots, in order, from consecutive source elements."""
        n_src, n_dst = 3, 10
        target = torch.randn(n_dst)
        source = torch.randn(n_src)
        mask = torch.ByteTensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0))
        reference = target.clone()
        target.masked_scatter_(mask, source)

        # Replay by hand: the k-th set bit receives the k-th source element.
        chosen = [pos for pos in range(n_dst) if mask[pos]]
        for k, pos in enumerate(chosen):
            reference[pos] = source[k]
        self.assertEqual(target, reference, 0)

        # A source longer than the number of set bits is accepted.
        source = torch.randn(n_dst)
        target.masked_scatter_(mask, source)

        # A source shorter than the number of set bits must be rejected.
        source = torch.randn(n_src - 1)
        with self.assertRaises(RuntimeError):
            target.masked_scatter_(mask, source)
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def test_deepcopy(self):
        """deepcopy must clone the values and keep tensor/storage/view sharing intact."""
        import copy

        first = torch.randn(5, 5)
        second = torch.randn(5, 5)
        flat_view = first.view(25)
        original = [first, [first.storage(), second.storage()], second, flat_view]
        cloned = copy.deepcopy(original)

        # Value equality, element by element of the nested structure.
        self.assertEqual(cloned[0], original[0], 0)
        for slot in range(2):
            self.assertEqual(cloned[1][slot], original[1][slot], 0)
        self.assertEqual(cloned[1], original[1], 0)
        self.assertEqual(cloned[2], original[2], 0)

        # In-place edits of a cloned tensor must show through the cloned
        # storage and the cloned view, while the originals stay untouched.
        cloned[0].add_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][0][pos], original[1][0][pos] + 1)
        self.assertEqual(cloned[3], flat_view + 1)
        cloned[2].sub_(1)
        for pos in range(first.numel()):
            self.assertEqual(cloned[1][1][pos], original[1][1][pos] - 1)
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_serialization_backwards_compat(self):
        """A legacy serialized file must load to the expected values, types and sharing."""
        tensors = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
        # Expected layout: the two tensors alternating twice, then the first
        # tensor's storage and a slice of that storage.
        expected = [tensors[i % 2] for i in range(4)]
        expected.append(tensors[0].storage())
        expected.append(tensors[0].storage()[1:4])

        path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
        loaded = torch.load(path)
        self.assertEqual(expected, loaded, 0)
        for slot in range(4):
            self.assertTrue(isinstance(loaded[slot], torch.FloatTensor))
        self.assertTrue(isinstance(loaded[4], torch.FloatStorage))

        # Entries 0/2 and 1/3 must alias each other, and entries 4/5 must
        # alias the storage behind entry 0.
        loaded[0].fill_(10)
        self.assertEqual(loaded[0], loaded[2], 0)
        self.assertEqual(loaded[4], torch.FloatStorage(25).fill_(10), 0)
        loaded[1].fill_(20)
        self.assertEqual(loaded[1], loaded[3], 0)
        self.assertEqual(loaded[4][1:4], loaded[5], 0)
dataset.py 文件源码 项目:sceneReco 作者: yijiuzai 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def __getitem__(self, index):
        """Return the ``(img, label)`` pair for the 0-based ``index``.

        The record is read from ``self.env`` under the 1-based keys
        ``image-%09d`` and ``label-%09d``. The image is decoded to
        grayscale (``'L'``) and passed through ``self.transform`` when set;
        the label is passed through ``self.target_transform`` when set.
        If the image cannot be decoded, the next sample is returned instead.

        Raises:
            AssertionError: if ``index`` is outside ``0 .. len(self) - 1``.
        """
        # Keys are 1-based (see ``index += 1``), so ``index == len(self)``
        # would ask for a record one past the end; negative values would
        # produce a key that was never written.
        # NOTE(review): presumably keys run 1..len(self) -- confirm against the writer.
        assert 0 <= index < len(self), 'index range error'
        index += 1
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                # ``index`` already holds the 1-based key of the bad record,
                # which equals the 0-based position of the following sample.
                # The original ``self[index + 1]`` skipped that sample.
                return self[index]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % index
            label = str(txn.get(label_key))
            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
DCN.py 文件源码 项目:DCN 作者: alexnowakvila 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def eliminate_rows(self, prob_sc, ind, phis):
        """Move thresholded rows to the end and deactivate them.

        Positions where ``prob_sc[:, :, 0] > 0.85`` are masked. The
        remaining positions are sorted to the front along dim 1, and the
        same reordering is applied to ``prob_sc``, ``ind`` and (through the
        reordered ``ind``) to ``phis`` on dims 1 and 2.

        Args:
            prob_sc: indexed as ``[:, :, 0]``, so at least 3-D; presumably
                (batch, length, channels) -- TODO confirm.
            ind: index tensor gathered along dim 1 with the sort order.
            phis: tensor gathered along dims 1 and 2; presumably
                (batch, length, length) -- TODO confirm.

        Returns:
            Tuple ``(prob_sc, ind, phis_out, active)`` where ``active`` is
            ``1 - mask`` in the new order.
        """
        length = prob_sc.size()[1]
        # 1 where channel 0 exceeds the threshold, else 0, cast to the
        # ``dtype`` name defined outside this block.
        mask = (prob_sc[:, :, 0] > 0.85).type(dtype)
        # Position values 0 .. length-1, repeated to match ``mask``.
        rang = (Variable(torch.range(0, length - 1).unsqueeze(0)
                .expand_as(mask)).
                type(dtype))
        # Unmasked positions keep their own position as sort key; masked
        # ones get ``length``, which is larger than every position, so the
        # sort places them last. Only the sort indices are kept.
        ind_sc = torch.sort(rang * (1-mask) + length * mask, 1)[1]
        # permute prob_sc
        m = mask.unsqueeze(2).expand_as(prob_sc)
        # ``mm`` keeps the mask on channel 0 only, so masked rows become
        # 1 on channel 0 and 0 on every other channel before the gather.
        mm = m.clone()
        mm[:, :, 1:] = 0
        prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
                   ind_sc.unsqueeze(2).expand_as(prob_sc)))
        # compose permutations
        ind = torch.gather(ind, 1, ind_sc)
        # 1 for positions that were not masked, in the new order.
        active = torch.gather(1-mask, 1, ind_sc)
        # permute phis
        # Index and activity tensors are expanded once per gathered dim.
        active1 = active.unsqueeze(2).expand_as(phis)
        ind1 = ind.unsqueeze(2).expand_as(phis)
        active2 = active.unsqueeze(1).expand_as(phis)
        ind2 = ind.unsqueeze(1).expand_as(phis)
        # Gather along dim 1 then dim 2, zeroing inactive entries each time.
        phis_out = torch.gather(phis, 1, ind1) * active1
        phis_out = torch.gather(phis_out, 2, ind2) * active2
        return prob_sc, ind, phis_out, active


问题


面经


文章

微信
公众号

扫码关注公众号