Python randperm() usage examples (source code)
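torch.randperm(n) returns a random permutation of the integers 0 to n-1 as an integer tensor (a LongTensor in current PyTorch; some of the older snippets below call .long() explicitly). The snippets collected here use it to shuffle dataset indices, build samplers, and generate random index tensors for tests. As a minimal orientation sketch, not taken from any of the projects below:

import torch

perm = torch.randperm(5)                       # e.g. tensor([2, 0, 4, 1, 3]); order varies per call
data = torch.arange(10, 20)
shuffled = data[torch.randperm(data.size(0))]  # shuffle a tensor along dim 0 by random index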

demo.py (project: efficient_densenet_pytorch, author: gpleiss)
def _make_dataloaders(train_set, valid_set, test_set, train_size, valid_size, batch_size):
    # Split training into train and validation
    indices = torch.randperm(len(train_set))
    train_indices = indices[:len(indices)-valid_size][:train_size or None]
    valid_indices = indices[len(indices)-valid_size:] if valid_size else None

    train_loader = torch.utils.data.DataLoader(train_set, pin_memory=True, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(train_indices))
    test_loader = torch.utils.data.DataLoader(test_set, pin_memory=True, batch_size=batch_size)
    if valid_size:
        valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size,
                                                   sampler=SubsetRandomSampler(valid_indices))
    else:
        valid_loader = None

    return train_loader, valid_loader, test_loader
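
A hedged usage sketch of the helper above, assuming torchvision CIFAR-10 datasets; the dataset choice, split sizes, and batch size are illustrative, not taken from the original project.

import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_set = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
valid_set = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10('./data', train=False, download=True, transform=transform)

train_loader, valid_loader, test_loader = _make_dataloaders(
    train_set, valid_set, test_set,
    train_size=45000, valid_size=5000, batch_size=64)
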
trainer.py (project: treelstm.pytorch, author: dasguptar)
def train(self, dataset):
        self.model.train()
        self.optimizer.zero_grad()
        total_loss = 0.0
        indices = torch.randperm(len(dataset))
        for idx in tqdm(range(len(dataset)),desc='Training epoch ' + str(self.epoch + 1) + ''):
            ltree, lsent, rtree, rsent, label = dataset[indices[idx]]
            linput, rinput = Var(lsent), Var(rsent)
            target = Var(map_label_to_target(label, dataset.num_classes))
            if self.args.cuda:
                linput, rinput = linput.cuda(), rinput.cuda()
                target = target.cuda()
            output = self.model(ltree, linput, rtree, rinput)
            loss = self.criterion(output, target)
            total_loss += loss.data[0]
            loss.backward()
            if idx % self.args.batchsize == 0 and idx > 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
        self.epoch += 1
        return total_loss / len(dataset)

    # helper function for testing
SpatialConvolutionMap.py (project: pytorch-dist, author: apaszke)
def random(nin, nout, nto):
            nker = nto * nout
            tbl = torch.Tensor(nker, 2)
            fi = torch.randperm(nin)
            frcntr = 0
            nfi = math.floor(nin / nto) # number of distinct nto chunks
            totbl = tbl.select(1, 1)
            frtbl = tbl.select(1, 0)
            fitbl = fi.narrow(0, 0, (nfi * nto)) # part of fi that covers distinct chunks
            ufrtbl = frtbl.unfold(0, nto, nto)
            utotbl = totbl.unfold(0, nto, nto)
            ufitbl = fitbl.unfold(0, nto, nto)

            # start filling frtbl
            for i in range(nout): # for each unit in target map
                ufrtbl.select(0, i).copy_(ufitbl.select(0, frcntr))
                frcntr += 1
                if frcntr-1 == nfi: # reset fi
                    fi.copy_(torch.randperm(nin))
                    frcntr = 1

            for tocntr in range(utotbl.size(0)):
                utotbl.select(0, tocntr).fill_(tocntr)

            return tbl
test_torch.py (project: pytorch-dist, author: apaszke)
def test_index_copy(self):
        num_copy, num_dest = 3, 20
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy).long()
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].copy_(src[i])
        self.assertEqual(dest, dest2, 0)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy).long()
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = src[i]
        self.assertEqual(dest, dest2, 0)
test_torch.py (project: pytorch-dist, author: apaszke)
def test_index_add(self):
        num_copy, num_dest = 3, 3
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy).long()
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].add_(src[i])
        self.assertEqual(dest, dest2)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy).long()
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = dest2[idx[i]] + src[i]
        self.assertEqual(dest, dest2)

    # Fill idx with valid indices.
shuffledataset.py (project: tnt, author: pytorch)
def resample(self, seed=None):
        """Resample the dataset.

        Args:
            seed (int, optional): Seed for resampling. By default no seed is
            used.
        """
        if seed is not None:
            gen = torch.manual_seed(seed)
        else:
            gen = torch.default_generator

        if self.replacement:
            self.perm = torch.LongTensor(len(self)).random_(
                len(self.dataset), generator=gen)
        else:
            self.perm = torch.randperm(
                len(self.dataset), generator=gen).narrow(0, 0, len(self))
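
A small standalone sketch of the two branches above (sizes are illustrative): with replacement the drawn indices may repeat, while torch.randperm guarantees each index appears at most once.

import torch

n_dataset, n_sample = 8, 5
with_replacement = torch.LongTensor(n_sample).random_(n_dataset)        # values in [0, 8), repeats possible
without_replacement = torch.randperm(n_dataset).narrow(0, 0, n_sample)  # first 5 entries of a permutation of 0..7
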
SpatialConvolutionMap.py (project: pytorch, author: tylergenter)
def random(nin, nout, nto):
            nker = nto * nout
            tbl = torch.Tensor(nker, 2)
            fi = torch.randperm(nin)
            frcntr = 0
            nfi = math.floor(nin / nto)  # number of distinct nto chunks
            totbl = tbl.select(1, 1)
            frtbl = tbl.select(1, 0)
            fitbl = fi.narrow(0, 0, (nfi * nto))  # part of fi that covers distinct chunks
            ufrtbl = frtbl.unfold(0, nto, nto)
            utotbl = totbl.unfold(0, nto, nto)
            ufitbl = fitbl.unfold(0, nto, nto)

            # start filling frtbl
            for i in range(nout):  # for each unit in target map
                ufrtbl.select(0, i).copy_(ufitbl.select(0, frcntr))
                frcntr += 1
                if frcntr - 1 == nfi:  # reset fi
                    fi.copy_(torch.randperm(nin))
                    frcntr = 1

            for tocntr in range(utotbl.size(0)):
                utotbl.select(0, tocntr).fill_(tocntr)

            return tbl
test_torch.py (project: pytorch, author: tylergenter)
def test_index_copy(self):
        num_copy, num_dest = 3, 20
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].copy_(src[i])
        self.assertEqual(dest, dest2, 0)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = src[i]
        self.assertEqual(dest, dest2, 0)
test_torch.py (project: pytorch, author: tylergenter)
def test_index_add(self):
        num_copy, num_dest = 3, 3
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].add_(src[i])
        self.assertEqual(dest, dest2)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = dest2[idx[i]] + src[i]
        self.assertEqual(dest, dest2)

    # Fill idx with valid indices.
SpatialConvolutionMap.py (project: pytorch-coriander, author: hughperkins)
def random(nin, nout, nto):
            nker = nto * nout
            tbl = torch.Tensor(nker, 2)
            fi = torch.randperm(nin)
            frcntr = 0
            nfi = math.floor(nin / nto)  # number of distinct nto chunks
            totbl = tbl.select(1, 1)
            frtbl = tbl.select(1, 0)
            fitbl = fi.narrow(0, 0, (nfi * nto))  # part of fi that covers distinct chunks
            ufrtbl = frtbl.unfold(0, nto, nto)
            utotbl = totbl.unfold(0, nto, nto)
            ufitbl = fitbl.unfold(0, nto, nto)

            # start filling frtbl
            for i in range(nout):  # for each unit in target map
                ufrtbl.select(0, i).copy_(ufitbl.select(0, frcntr))
                frcntr += 1
                if frcntr - 1 == nfi:  # reset fi
                    fi.copy_(torch.randperm(nin))
                    frcntr = 1

            for tocntr in range(utotbl.size(0)):
                utotbl.select(0, tocntr).fill_(tocntr)

            return tbl
test_torch.py (project: pytorch-coriander, author: hughperkins)
def test_index_copy(self):
        num_copy, num_dest = 3, 20
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].copy_(src[i])
        self.assertEqual(dest, dest2, 0)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = src[i]
        self.assertEqual(dest, dest2, 0)
test_torch.py (project: pytorch-coriander, author: hughperkins)
def test_index_add(self):
        num_copy, num_dest = 3, 3
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].add_(src[i])
        self.assertEqual(dest, dest2)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = dest2[idx[i]] + src[i]
        self.assertEqual(dest, dest2)

    # Fill idx with valid indices.
distributed.py (project: pytorch, author: ezyang)
def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = list(torch.randperm(len(self.dataset), generator=g))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)
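
For intuition, a standalone sketch of the padding-and-slicing logic above with illustrative sizes (10 samples across 3 ranks); local variables stand in for self.epoch, self.num_samples and so on.

import torch

g = torch.Generator()
g.manual_seed(0)                                                # plays the role of self.epoch
dataset_len, num_replicas = 10, 3
num_samples = (dataset_len + num_replicas - 1) // num_replicas  # ceil division -> 4 per rank
total_size = num_samples * num_replicas                         # 12

indices = list(torch.randperm(dataset_len, generator=g))
indices += indices[:(total_size - len(indices))]                # pad to 12 by repeating the head
for rank in range(num_replicas):
    offset = num_samples * rank
    print(rank, indices[offset:offset + num_samples])           # each rank gets a disjoint slice of 4
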
SpatialConvolutionMap.py (project: pytorch, author: ezyang)
def random(nin, nout, nto):
            nker = nto * nout
            tbl = torch.Tensor(nker, 2)
            fi = torch.randperm(nin)
            frcntr = 0
            nfi = math.floor(nin / nto)  # number of distinct nto chunks
            totbl = tbl.select(1, 1)
            frtbl = tbl.select(1, 0)
            fitbl = fi.narrow(0, 0, (nfi * nto))  # part of fi that covers distinct chunks
            ufrtbl = frtbl.unfold(0, nto, nto)
            utotbl = totbl.unfold(0, nto, nto)
            ufitbl = fitbl.unfold(0, nto, nto)

            # start filling frtbl
            for i in range(nout):  # for each unit in target map
                ufrtbl.select(0, i).copy_(ufitbl.select(0, frcntr))
                frcntr += 1
                if frcntr - 1 == nfi:  # reset fi
                    fi.copy_(torch.randperm(nin))
                    frcntr = 1

            for tocntr in range(utotbl.size(0)):
                utotbl.select(0, tocntr).fill_(tocntr)

            return tbl
test_torch.py (project: pytorch, author: ezyang)
def test_index_copy(self):
        num_copy, num_dest = 3, 20
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].copy_(src[i])
        self.assertEqual(dest, dest2, 0)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = src[i]
        self.assertEqual(dest, dest2, 0)
pytorch_run.py (project: pytorch-avitm, author: hyqneuron)
def train():
    for epoch in xrange(args.num_epoch):
        all_indices = torch.randperm(tensor_tr.size(0)).split(args.batch_size)
        loss_epoch = 0.0
        model.train()                   # switch to training mode
        for batch_indices in all_indices:
            if not args.nogpu: batch_indices = batch_indices.cuda()
            input = Variable(tensor_tr[batch_indices])
            recon, loss = model(input, compute_loss=True)
            # optimize
            optimizer.zero_grad()       # clear previous gradients
            loss.backward()             # backprop
            optimizer.step()            # update parameters
            # report
            loss_epoch += loss.data[0]    # add loss to loss_epoch
        if epoch % 5 == 0:
            print('Epoch {}, loss={}'.format(epoch, loss_epoch / len(all_indices)))
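
The batching idiom above amounts to shuffling the row indices once per epoch and splitting them into batch-sized chunks; a minimal sketch with illustrative sizes:

import torch

all_indices = torch.randperm(10).split(4)   # three chunks of sizes 4, 4 and 2
for batch_indices in all_indices:
    print(batch_indices)
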
trainer.py (project: treelstm-pytorch, author: pklfz)
def train(self, dataset):
        self.model.train()
        self.optimizer.zero_grad()
        loss, k = 0.0, 0
        indices = torch.randperm(len(dataset))
        for idx in tqdm(range(len(dataset)), desc='Training epoch ' + str(self.epoch + 1) + ''):
            ltree, lsent, rtree, rsent, label = dataset[indices[idx]]
            linput, rinput = Var(lsent), Var(rsent)
            target = Var(map_label_to_target(label, dataset.num_classes))
            if self.args.cuda:
                linput, rinput = linput.cuda(), rinput.cuda()
                target = target.cuda()
            output = self.model(ltree, linput, rtree, rinput)
            err = self.criterion(output, target)
            loss += err.data[0]
            err.backward()
            k += 1
            if k % self.args.batchsize == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
        self.epoch += 1
        return loss / len(dataset)

    # helper function for testing
distortion_transforms.py (project: torchsample, author: ncullen93)
def __call__(self, *inputs):
        outputs = []
        for idx, _input in enumerate(inputs):
            size = _input.size()
            img_height = size[1]
            img_width = size[2]

            x_blocks = int(img_height/self.blocksize) # number of x blocks
            y_blocks = int(img_width/self.blocksize)
            ind = th.randperm(x_blocks*y_blocks)

            new = th.zeros(_input.size())
            count = 0
            for i in range(x_blocks):
                for j in range (y_blocks):
                    row = int(ind[count] / x_blocks)
                    column = ind[count] % x_blocks
                    new[:, i*self.blocksize:(i+1)*self.blocksize, j*self.blocksize:(j+1)*self.blocksize] = \
                    _input[:, row*self.blocksize:(row+1)*self.blocksize, column*self.blocksize:(column+1)*self.blocksize]
                    count += 1
            outputs.append(new)
        return outputs if idx > 1 else outputs[0]
distributed.py (project: pytorch, author: pytorch)
def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = list(torch.randperm(len(self.dataset), generator=g))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)
SpatialConvolutionMap.py (project: pytorch, author: pytorch)
def random(nin, nout, nto):
            nker = nto * nout
            tbl = torch.Tensor(nker, 2)
            fi = torch.randperm(nin)
            frcntr = 0
            nfi = math.floor(nin / nto)  # number of distinct nto chunks
            totbl = tbl.select(1, 1)
            frtbl = tbl.select(1, 0)
            fitbl = fi.narrow(0, 0, (nfi * nto))  # part of fi that covers distinct chunks
            ufrtbl = frtbl.unfold(0, nto, nto)
            utotbl = totbl.unfold(0, nto, nto)
            ufitbl = fitbl.unfold(0, nto, nto)

            # start filling frtbl
            for i in range(nout):  # for each unit in target map
                ufrtbl.select(0, i).copy_(ufitbl.select(0, frcntr))
                frcntr += 1
                if frcntr - 1 == nfi:  # reset fi
                    fi.copy_(torch.randperm(nin))
                    frcntr = 1

            for tocntr in range(utotbl.size(0)):
                utotbl.select(0, tocntr).fill_(tocntr)

            return tbl
test_torch.py (project: pytorch, author: pytorch)
def test_index_copy(self):
        num_copy, num_dest = 3, 20
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].copy_(src[i])
        self.assertEqual(dest, dest2, 0)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = src[i]
        self.assertEqual(dest, dest2, 0)
test_torch.py (project: pytorch, author: pytorch)
def test_index_add(self):
        num_copy, num_dest = 3, 3
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]].add_(src[i])
        self.assertEqual(dest, dest2)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = dest2[idx[i]] + src[i]
        self.assertEqual(dest, dest2)
misc.py (project: verb-attributes, author: uwnlp)
def cosine_ranking_loss(input_data, ctx, margin=0.1):
    """
    :param input_data: [batch_size, 300] tensor of predictions
    :param ctx: [batch_size, 300] tensor of ground truths
    :param margin: Minimum required gap between the correct and shuffled similarities
    :return: (cost, correct_contrib, incorrect_contrib): per-example hinge cost and the two similarity terms
    """
    normed = _normalize(input_data)
    ctx_normed = _normalize(ctx)
    shuff_inds = torch.randperm(normed.size(0))
    if ctx.is_cuda:
        shuff_inds = shuff_inds.cuda()
    shuff = ctx_normed[shuff_inds]

    correct_contrib = torch.sum(normed * ctx_normed, 1).squeeze()
    incorrect_contrib = torch.sum(normed * shuff, 1).squeeze()

    # similarity = torch.mm(normed, ctx_normed.t()) #[predictions, gts]
    # correct_contrib = similarity.diag()
    # incorrect_contrib = incorrect_contrib.sum(1).squeeze()/(incorrect_contrib.size(1)-1.0)
    #
    cost = (margin + incorrect_contrib - correct_contrib).clamp(min=0)

    return cost, correct_contrib, incorrect_contrib
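
For intuition, the loss above is a hinge on cosine similarities: max(0, margin + sim(prediction, shuffled context) - sim(prediction, true context)). A tiny numeric sketch with illustrative values:

import torch

margin = 0.1
correct_contrib = torch.tensor([0.9, 0.2])     # cosine similarity with the true context
incorrect_contrib = torch.tensor([0.1, 0.5])   # cosine similarity with a shuffled context
cost = (margin + incorrect_contrib - correct_contrib).clamp(min=0)
print(cost)                                    # tensor([0.0000, 0.4000]); only the second pair violates the margin
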
trainer.py (project: TreeLSTMSentiment, author: ttpro1995)
def train(self, dataset):
        self.model.train()
        self.optimizer.zero_grad()
        loss, k = 0.0, 0
        indices = torch.randperm(len(dataset))
        for idx in tqdm(xrange(len(dataset)),desc='Training epoch '+str(self.epoch+1)+''):
            ltree,lsent,rtree,rsent,label = dataset[indices[idx]]
            linput, rinput = Var(lsent), Var(rsent)
            target = Var(map_label_to_target(label,dataset.num_classes))
            if self.args.cuda:
                linput, rinput = linput.cuda(), rinput.cuda()
                target = target.cuda()
            output = self.model(ltree,linput,rtree,rinput)
            err = self.criterion(output, target)
            loss += err.data[0]
            err.backward()
            k += 1
            if k%self.args.batchsize==0:
                self.optimizer.step()
                self.optimizer.zero_grad()
        self.epoch += 1
        return loss/len(dataset)

    # helper function for testing
preprocess.py (project: convNet.pytorch, author: eladhoffer)
def __call__(self, img):
        if self.transforms is None:
            return img
        order = torch.randperm(len(self.transforms))
        for i in order:
            img = self.transforms[i](img)
        return img
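
A hedged usage sketch of the transform above; the enclosing class name RandomOrder and its constructor taking the transform list are assumptions, and the transform choices are illustrative (torchvision ships a transform with the same behavior).

from PIL import Image
from torchvision import transforms

augment = RandomOrder([transforms.RandomHorizontalFlip(),       # RandomOrder: assumed class name
                       transforms.ColorJitter(brightness=0.2)])
pil_image = Image.new('RGB', (32, 32))
augmented = augment(pil_image)                                   # the two transforms are applied in a random order
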
sampler.py (project: pytorch-dist, author: apaszke)
def __iter__(self):
        return iter(torch.randperm(self.num_samples).long())
test_utils.py (project: pytorch-dist, author: apaszke)
def __iter__(self):
        for i in range(10):
            yield torch.randn(2, 10), torch.randperm(10)[:2]
test_dataloader.py (project: pytorch-dist, author: apaszke)
def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)
test_dataloader.py (project: pytorch-dist, author: apaszke)
def setUp(self):
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)
test_torch.py (project: pytorch-dist, author: apaszke)
def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    ii[dim] = slice(0, idx.size(dim)+1)
                    idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]

