python类sort()的实例源码

test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 49 收藏 0 点赞 0 评论 0
def assertIsOrdered(self, order, x, mxx, ixx, task):
    """Assert that ``(mxx, ixx)`` is a correct sort of ``x`` along its last dim.

    Checks that every row of ``mxx`` is monotone in the requested direction,
    that ``x[k][ixx[k][j]] == mxx[k][j]`` for every entry, and that no row of
    ``ixx`` repeats an index (i.e. each row is a permutation).

    Args:
        order: ``'ascending'`` or ``'descending'``.
        x: the tensor that was sorted; indexed as ``x[k][i]`` (treated as 2-D).
        mxx: sorted values, same shape as ``x``.
        ixx: sort indices, same shape as ``x``.
        task: label included in failure messages.

    Raises:
        ValueError: if ``order`` is not one of the two accepted strings.
    """
    if order == 'descending':
        check_order = lambda a, b: a >= b
    elif order == 'ascending':
        check_order = lambda a, b: a <= b
    else:
        raise ValueError('unknown order "{}", must be "ascending" or "descending"'.format(order))

    # Take the extent from the tensor itself instead of assuming a 4x4 input.
    num_rows = x.size(0)
    size = x.size(x.dim() - 1)

    # Values must be monotone within each row.
    for j, k in product(range(num_rows), range(1, size)):
        self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]),
                        'torch.sort ({}) values unordered for {}'.format(order, task))

    # Indices must point back at the sorted values and form a permutation.
    for k in range(num_rows):
        seen = set()
        for j in range(size):
            # int(): 0-d tensors hash by identity, so a set of them would
            # never detect a repeated index.
            idx = int(ixx[k][j])
            self.assertEqual(x[k][idx], mxx[k][j],
                             'torch.sort ({}) indices wrong for {}'.format(order, task))
            seen.add(idx)
        self.assertEqual(len(seen), size)
coref.py 文件源码 项目:allennlp 作者: allenai 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def _prune_and_sort_spans(mention_scores: torch.FloatTensor,
                              num_spans_to_keep: int) -> torch.IntTensor:
        """
        The indices of the top-k scoring spans according to span_scores. We return the
        indices in their original order, not ordered by score, so that we can rely on
        the ordering to consider the previous k spans as antecedents for each span later.

        Parameters
        ----------
        mention_scores : ``torch.FloatTensor``, required.
            The mention score for every candidate, with shape (batch_size, num_spans, 1).
        num_spans_to_keep : ``int``, required.
            The number of spans to keep when pruning.
        Returns
        -------
        top_span_indices : ``torch.IntTensor``, required.
            The indices of the top-k scoring spans. Has shape (batch_size, num_spans_to_keep).
        """
        # Shape: (batch_size, num_spans_to_keep, 1)
        _, top_span_indices = mention_scores.topk(num_spans_to_keep, 1)
        top_span_indices, _ = torch.sort(top_span_indices, 1)

        # Shape: (batch_size, num_spans_to_keep)
        top_span_indices = top_span_indices.squeeze(-1)
        return top_span_indices
utils.py 文件源码 项目:MIL.pytorch 作者: gujiuxiang 项目源码 文件源码 阅读 41 收藏 0 点赞 0 评论 0
def compute_precision_mapping(pt, num_classes=1000):
    """Build a monotone threshold -> precision curve for each class.

    For column ``jj`` of ``pt['details']['score']`` and
    ``pt['details']['precision']``, the points are sorted by score threshold.
    Precision is made non-decreasing with a running maximum, and duplicate
    thresholds are collapsed to their first occurrence. The curve is then
    padded with sentinel thresholds (at most -1000 and at least 1000), which
    repeat the first and last precision.

    Args:
        pt: mapping whose ``['details']['score']`` and ``['details']['precision']``
            are 2-D arrays with one column per class.
        num_classes: number of leading columns to process (default 1000).

    Returns:
        ``{'thresh': [...], 'prec': [...]}`` with one column array of shape
        ``(m, 1)`` per class. ``pt`` is not modified.
    """
    scores = pt['details']['score']
    precisions = pt['details']['precision']
    thresh_all = []
    prec_all = []
    for jj in range(num_classes):
        thresh = scores[:, jj]
        prec = precisions[:, jj]
        ind = np.argsort(thresh)  # numpy stand-in for torch.sort(thresh)
        thresh = thresh[ind]
        # First position of every distinct threshold, in increasing order.
        indexes = np.sort(np.unique(thresh, return_index=True)[1])
        thresh = thresh[indexes]

        thresh = np.vstack((min(-1000, thresh.min() - 1),
                            thresh[:, np.newaxis],
                            max(1000, thresh.max() + 1)))

        # Running maximum in threshold order; equivalent to the sequential
        # prec[i] = max(prec[i], prec[i - 1]).
        prec = np.maximum.accumulate(prec[ind])[indexes]
        prec = np.vstack((prec[0], prec[:, np.newaxis], prec[-1]))

        thresh_all.append(thresh)
        prec_all.append(prec)
    return {'thresh': thresh_all, "prec": prec_all}
utils.py 文件源码 项目:MIL.pytorch 作者: gujiuxiang 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def compute_precision_score_mapping(thresh, prec, score):
    """Map ``score`` to a precision value through a monotone threshold curve.

    ``thresh`` and ``prec`` are parallel 1-D arrays of operating points. They are
    sorted by threshold, and precision is made non-decreasing with a running
    maximum. Duplicate thresholds are collapsed to their first occurrence. The
    curve is padded with sentinel thresholds (at most -1000 and at least 1000),
    so any ``score`` in that range can be linearly interpolated.

    Returns:
        The value produced by ``scipy.interpolate.interp1d`` at ``score``.
        The inputs are not modified.
    """
    ind = np.argsort(thresh)  # numpy stand-in for torch.sort(thresh)
    thresh = thresh[ind]
    # First position of every distinct threshold, in increasing order.
    indexes = np.sort(np.unique(thresh, return_index=True)[1])
    thresh = thresh[indexes]

    thresh = np.vstack((min(-1000, thresh.min() - 1),
                        thresh[:, np.newaxis],
                        max(1000, thresh.max() + 1)))

    # Running maximum in threshold order; equivalent to the sequential
    # prec[i] = max(prec[i], prec[i - 1]).
    prec = np.maximum.accumulate(prec[ind])[indexes]
    prec = np.vstack((prec[0], prec[:, np.newaxis], prec[-1]))

    f = interp1d(thresh[:, 0], prec[:, 0])
    return f(score)
segment.py 文件源码 项目:jack 作者: uclmr 项目源码 文件源码 阅读 35 收藏 0 点赞 0 评论 0
def backward(ctx, grad_outputs):
    """Backward pass for the segment-max function.

    Rows of ``grad_outputs`` are reordered using ``ctx.rev_segm_sorted``. They
    are then consumed group by group, following ``ctx.num_lengths`` (entry ``l``
    is the count ``n`` for length ``l``), starting at row ``ctx.num_zeros``. For
    lengths > 1, each group is passed through ``_MyMax.backward`` with
    ``ctx.maxes[l]``. Finally, the concatenated result is reordered via
    ``ctx.lengths_sorted``.

    Returns:
        ``(grads, None, None, None)``; only the first input receives a gradient.
    """
    width = grad_outputs.size(1)
    # argsort of the stored permutation gives its inverse.
    restore_segm = torch.sort(ctx.rev_segm_sorted)[1]
    grad_outputs = torch.index_select(grad_outputs, 0, restore_segm)

    cursor = ctx.num_zeros
    pieces = []
    for seg_len, count in enumerate(ctx.num_lengths):
        if count <= 0:
            continue
        num_rows = count // seg_len
        piece = grad_outputs.narrow(0, cursor, num_rows)
        if seg_len > 1:
            piece = _MyMax.backward(ctx.maxes[seg_len], piece)[0].view(count, width)
        cursor += num_rows
        pieces.append(piece)

    grads = torch.cat(pieces, 0)
    grads = torch.index_select(grads, 0, torch.sort(ctx.lengths_sorted)[1])
    return grads, None, None, None
Dict.py 文件源码 项目:bandit-nmt 作者: khanhptnk 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def prune(self, size):
    """Return a new dictionary with the `size` most frequent entries.

    The special entries of this dictionary are always re-added first. Returns
    ``self`` unchanged when ``size`` is not smaller than ``self.size()``.
    """
    if size >= self.size():
        return self

    # Only keep the `size` most frequent entries.
    freq = torch.Tensor(
        [self.frequencies[i] for i in range(len(self.frequencies))])
    _, idx = torch.sort(freq, 0, True)

    newDict = Dict()

    # Add special entries in all cases.
    for i in self.special:
        newDict.addSpecial(self.idxToLabel[i])

    # .tolist() gives plain ints. Iterating a tensor yields 0-d tensors in newer
    # PyTorch, and those hash by identity, so they miss int keys of a dict.
    for i in idx[:size].tolist():
        newDict.add(self.idxToLabel[i])

    return newDict

    # Convert `labels` to indices. Use `unkWord` if not found.
    # Optionally insert `bosWord` at the beginning and `eosWord` at the end.
utils.py 文件源码 项目:pytorch-caffe-darknet-convert 作者: marvis 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def nms(boxes, nms_thresh):
    """Greedy non-maximum suppression.

    Boxes are visited in decreasing confidence (``box[4]``). A box is kept unless
    its confidence is not positive, or an already-kept box overlaps it with
    ``bbox_iou(..., x1y1x2y2=False) > nms_thresh``. Suppression is tracked
    locally, so the caller's boxes are not modified.

    Returns:
        The kept boxes (same objects as in ``boxes``), or ``boxes`` itself when empty.
    """
    if len(boxes) == 0:
        return boxes

    det_confs = torch.zeros(len(boxes))
    for i in range(len(boxes)):
        det_confs[i] = 1 - boxes[i][4]
    # Ascending sort of (1 - conf) == descending order of confidence.
    _, sortIds = torch.sort(det_confs)
    order = [int(k) for k in sortIds]

    suppressed = [False] * len(boxes)
    out_boxes = []
    for pos, i in enumerate(order):
        box_i = boxes[i]
        if suppressed[i] or not box_i[4] > 0:
            continue
        out_boxes.append(box_i)
        for j in order[pos + 1:]:
            # Skip boxes already suppressed; their IoU cannot change the outcome.
            if not suppressed[j] and bbox_iou(box_i, boxes[j], x1y1x2y2=False) > nms_thresh:
                suppressed[j] = True
    return out_boxes
Dict.py 文件源码 项目:NeuralMT 作者: hlt-mt 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def prune(self, size):
    """Return a new dictionary with the `size` most frequent entries.

    The new dictionary copies ``self.lower`` and always re-adds the special
    entries first. Returns ``self`` unchanged when ``size`` is not smaller than
    ``self.size()``.
    """
    if size >= self.size():
        return self

    # Only keep the `size` most frequent entries.
    freq = torch.Tensor(
        [self.frequencies[i] for i in range(len(self.frequencies))])
    _, idx = torch.sort(freq, 0, True)

    newDict = Dict()
    newDict.lower = self.lower

    # Add special entries in all cases.
    for i in self.special:
        newDict.addSpecial(self.idxToLabel[i])

    # .tolist() gives plain ints. Iterating a tensor yields 0-d tensors in newer
    # PyTorch, and those hash by identity, so they miss int keys of a dict.
    for i in idx[:size].tolist():
        newDict.add(self.idxToLabel[i])

    return newDict
ssn_ops.py 文件源码 项目:action-detection 作者: yjxiong 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def forward(ctx, pred, labels, is_positive, ohem_ratio, group_size):
    """Hinge loss with online hard example mining inside fixed-size groups.

    For each sample ``i`` the loss is ``max(0, 1 - is_positive * pred[i, labels[i] - 1])``
    (labels are presumably 1-based class ids). Losses are split into
    consecutive groups of ``group_size``. Within each group only the
    ``int(group_size * ohem_ratio)`` largest losses are summed. The kept
    indices, labels, slopes, input shape and group info are stored on ``ctx``.

    Returns:
        A 1-element tensor on ``pred``'s device holding the summed kept losses.
    """
    n_sample = pred.size()[0]
    assert n_sample == len(labels), "mismatch between sample size and label size"
    losses = torch.zeros(n_sample)
    slopes = torch.zeros(n_sample)
    for i in range(n_sample):
        losses[i] = max(0, 1 - is_positive * pred[i, labels[i] - 1])
        # Derivative of the hinge w.r.t. the score: -is_positive while active, else 0.
        slopes[i] = -is_positive if losses[i] != 0 else 0

    losses = losses.view(-1, group_size).contiguous()
    sorted_losses, indices = torch.sort(losses, dim=1, descending=True)
    keep_num = int(group_size * ohem_ratio)
    # Allocate on pred's device instead of a hard-coded .cuda() (which broke
    # CPU-only use), and sum every group's kept losses in one call.
    loss = torch.zeros(1, device=pred.device)
    loss += sorted_losses[:, :keep_num].sum()
    ctx.loss_ind = indices[:, :keep_num]
    ctx.labels = labels
    ctx.slopes = slopes
    ctx.shape = pred.size()
    ctx.group_size = group_size
    ctx.num_group = losses.size(0)
    return loss
pooling.py 文件源码 项目:wildcat.pytorch 作者: durandtibo 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def forward(self, input):
    """Spatial top-k / bottom-k pooling.

    ``input`` is indexed as (batch, channel, h, w). For every (batch, channel)
    map, the ``kmax`` largest activations are averaged. If ``kmin > 0`` and
    ``alpha != 0``, ``alpha`` times the mean of the ``kmin`` smallest
    activations is added and the result halved. The selected indices are
    stored on ``self`` and ``input`` is saved for backward.

    Returns:
        Tensor of shape (batch_size, num_channels).
    """
    batch_size = input.size(0)
    num_channels = input.size(1)
    h = input.size(2)
    w = input.size(3)

    n = h * w  # number of regions

    kmax = self.get_positive_k(self.kmax, n)
    kmin = self.get_positive_k(self.kmin, n)

    sorted_vals, indices = torch.sort(input.view(batch_size, num_channels, n),
                                      dim=2, descending=True)

    self.indices_max = indices.narrow(2, 0, kmax)
    output = sorted_vals.narrow(2, 0, kmax).sum(2).div_(kmax)

    # `!=`, not `is not`: `alpha is not 0` is True for alpha == 0.0, which
    # entered this branch and halved the output.
    if kmin > 0 and self.alpha != 0:
        self.indices_min = indices.narrow(2, n - kmin, kmin)
        output.add_(sorted_vals.narrow(2, n - kmin, kmin).sum(2).mul_(self.alpha / kmin)).div_(2)

    self.save_for_backward(input)
    return output.view(batch_size, num_channels)
util.py 文件源码 项目:wildcat.pytorch 作者: durandtibo 项目源码 文件源码 阅读 47 收藏 0 点赞 0 评论 0
def value(self):
    """Return the model's average precision for each class.

    Column ``k`` of ``self.scores`` / ``self.targets`` is passed to
    ``AveragePrecisionMeter.average_precision`` together with
    ``self.difficult_examples``.

    Returns:
        ap (FloatTensor): 1-D tensor of length K with the AP of each class k,
        or the int ``0`` when no scores have been recorded.
    """
    if self.scores.numel() == 0:
        return 0
    num_classes = self.scores.size(1)
    ap = torch.zeros(num_classes)

    for k in range(num_classes):
        ap[k] = AveragePrecisionMeter.average_precision(
            self.scores[:, k], self.targets[:, k], self.difficult_examples)
    return ap
utils.py 文件源码 项目:SeqMatchSeq 作者: pcgreat 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def MAP(ground_label: torch.FloatTensor, predict_label: torch.FloatTensor):
    """Average precision of one ranked list.

    Args:
        ground_label: 1-D relevance labels; nonzero entries are relevant.
        predict_label: 1-D predicted scores; items are ranked by descending score.

    Returns:
        Mean of ``hits / rank`` over the ranks where a relevant item appears.

    Raises:
        AssertionError: if no relevant item is found.
    """
    # Plain int positions, so membership tests work regardless of tensor type.
    relevant = {idx for idx, glab in enumerate(ground_label) if glab != 0}

    _, order = torch.sort(predict_label, 0, True)
    total = 0
    hits = 0
    # .tolist(): iterating a tensor yields 0-d tensors, which hash by identity
    # and would never match the int positions above.
    for rank, idx in enumerate(order.tolist(), start=1):
        if idx in relevant:
            hits += 1
            total += hits / rank

    assert (hits != 0)
    return total / hits
utils.py 文件源码 项目:SeqMatchSeq 作者: pcgreat 项目源码 文件源码 阅读 43 收藏 0 点赞 0 评论 0
def MRR(ground_label: torch.FloatTensor, predict_label: torch.FloatTensor):
    """Reciprocal rank of the first relevant item in one ranked list.

    Args:
        ground_label: 1-D relevance labels; nonzero entries are relevant.
        predict_label: 1-D predicted scores; items are ranked by descending score.

    Returns:
        ``1 / rank`` of the highest-ranked relevant item.

    Raises:
        AssertionError: if no relevant item is found.
    """
    # Plain int positions, so membership tests work regardless of tensor type.
    relevant = {idx for idx, glab in enumerate(ground_label) if glab != 0}

    _, order = torch.sort(predict_label, 0, True)
    mrr = 0
    # .tolist(): iterating a tensor yields 0-d tensors, which hash by identity
    # and would never match the int positions above.
    for rank, idx in enumerate(order.tolist(), start=1):
        if idx in relevant:
            mrr = 1.0 / rank
            break

    assert (mrr != 0)
    return mrr
Dict.py 文件源码 项目:alpha-dimt-icmlws 作者: sotetsuk 项目源码 文件源码 阅读 39 收藏 0 点赞 0 评论 0
def prune(self, size):
    """Return a new dictionary with the `size` most frequent entries.

    The new dictionary copies ``self.lower`` and always re-adds the special
    entries first. Returns ``self`` unchanged when ``size`` is not smaller than
    ``self.size()``.
    """
    if size >= self.size():
        return self

    # Only keep the `size` most frequent entries.
    freq = torch.Tensor(
        [self.frequencies[i] for i in range(len(self.frequencies))])
    _, idx = torch.sort(freq, 0, True)

    newDict = Dict()
    newDict.lower = self.lower

    # Add special entries in all cases.
    for i in self.special:
        newDict.addSpecial(self.idxToLabel[i])

    # .tolist() gives plain ints. Iterating a tensor yields 0-d tensors in newer
    # PyTorch, and those hash by identity, so they miss int keys of a dict.
    for i in idx[:size].tolist():
        newDict.add(self.idxToLabel[i])

    return newDict

    # Convert `labels` to indices. Use `unkWord` if not found.
    # Optionally insert `bosWord` at the beginning and `eosWord` at the end.
baseline_crf.py 文件源码 项目:imSitu 作者: my89 项目源码 文件源码 阅读 40 收藏 0 点赞 0 评论 0
def eval_model(dataset_loader, encoding, model):
    """Evaluate ``model`` over ``dataset_loader`` and return (top1, top5) evaluators.

    For each batch, the batch is moved to the GPU, ``model.forward_max`` is
    called, and the scores are sorted in descending order along dim 1. The
    targets, predictions and sorted indices are then fed to both evaluators.
    Gradients are disabled for the whole loop.
    """
    model.eval()
    print("evaluating model...")
    top1 = imSituTensorEvaluation(1, 3, encoding)
    top5 = imSituTensorEvaluation(5, 3, encoding)

    mx = len(dataset_loader)
    # torch.no_grad() replaces the removed `volatile=True` Variable flag.
    with torch.no_grad():
        for i, (index, batch_input, target) in enumerate(dataset_loader):
            # "\r" keeps the progress counter on one line.
            print("{}/{} batches\r".format(i + 1, mx), end="", flush=True)
            (scores, predictions) = model.forward_max(batch_input.cuda())
            _, idx = torch.sort(scores, 1, True)
            top1.add_point(target, predictions.data, idx.data)
            top5.add_point(target, predictions.data, idx.data)

    print("\ndone.")
    return (top1, top5)
utils.py 文件源码 项目:pytorch-yolo2 作者: marvis 项目源码 文件源码 阅读 39 收藏 0 点赞 0 评论 0
def nms(boxes, nms_thresh):
    """Greedy non-maximum suppression.

    Boxes are visited in decreasing confidence (``box[4]``). A box is kept unless
    its confidence is not positive, or an already-kept box overlaps it with
    ``bbox_iou(..., x1y1x2y2=False) > nms_thresh``. Suppression is tracked
    locally, so the caller's boxes are not modified.

    Returns:
        The kept boxes (same objects as in ``boxes``), or ``boxes`` itself when empty.
    """
    if len(boxes) == 0:
        return boxes

    det_confs = torch.zeros(len(boxes))
    for i in range(len(boxes)):
        det_confs[i] = 1 - boxes[i][4]
    # Ascending sort of (1 - conf) == descending order of confidence.
    _, sortIds = torch.sort(det_confs)
    order = [int(k) for k in sortIds]

    suppressed = [False] * len(boxes)
    out_boxes = []
    for pos, i in enumerate(order):
        box_i = boxes[i]
        if suppressed[i] or not box_i[4] > 0:
            continue
        out_boxes.append(box_i)
        for j in order[pos + 1:]:
            # Skip boxes already suppressed; their IoU cannot change the outcome.
            if not suppressed[j] and bbox_iou(box_i, boxes[j], x1y1x2y2=False) > nms_thresh:
                suppressed[j] = True
    return out_boxes
TripletLoss.py 文件源码 项目:pytorch-PersonReID 作者: huaijin-chen 项目源码 文件源码 阅读 60 收藏 0 点赞 0 评论 0
def forward(self, anchor, positive, negative):
    """Triplet hinge loss ``mean(clamp(d(a, p) - d(a, n) + margin, min=0))``.

    ``self.dist_type`` selects the measure:
      - 0: ``F.pairwise_distance``;
      - 1: ``cosine_similarity`` (a helper defined elsewhere).

    With ``self.use_ohem`` set, only the ``self.ohem_bs`` largest hinge terms are
    averaged.

    Raises:
        ValueError: for any other ``dist_type``.
    """
    if self.dist_type == 0:
        dist_p = F.pairwise_distance(anchor, positive)
        dist_n = F.pairwise_distance(anchor, negative)
    elif self.dist_type == 1:
        dist_p = cosine_similarity(anchor, positive)
        # Previously misspelled `disp_n`, which left dist_n unbound.
        dist_n = cosine_similarity(anchor, negative)
    else:
        raise ValueError('unsupported dist_type: {}'.format(self.dist_type))

    dist_hinge = torch.clamp(dist_p - dist_n + self.margin, min=0.0)
    if self.use_ohem:
        # Online hard example mining: keep only the hardest triplets.
        v, _ = torch.sort(dist_hinge, descending=True)
        loss = torch.mean(v[0:self.ohem_bs])
    else:
        loss = torch.mean(dist_hinge)

    return loss
utils.py 文件源码 项目:pretrained-models.pytorch 作者: Cadene 项目源码 文件源码 阅读 39 收藏 0 点赞 0 评论 0
def value(self):
    """Return the model's average precision for each class.

    Column ``k`` of ``self.scores`` / ``self.targets`` is passed to
    ``AveragePrecisionMeter.average_precision`` together with
    ``self.difficult_examples``.

    Returns:
        ap (FloatTensor): 1-D tensor of length K with the AP of each class k,
        or the int ``0`` when no scores have been recorded.
    """
    if self.scores.numel() == 0:
        return 0
    num_classes = self.scores.size(1)
    ap = torch.zeros(num_classes)

    for k in range(num_classes):
        ap[k] = AveragePrecisionMeter.average_precision(
            self.scores[:, k], self.targets[:, k], self.difficult_examples)
    return ap
DCN.py 文件源码 项目:DCN 作者: alexnowakvila 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def reindex_target(self, target, e):
    """Re-express ``target`` indices in the order induced by sorting ``e``.

    ``e`` is sorted along dim 1 and the permutation is inverted (argsort of a
    permutation is its inverse). For each of the first ``self.batch_size`` rows
    of the numpy array ``target``:
      - negative entries are masked to 0 and every entry is mapped through the
        inverse permutation;
      - the mapped values at non-negative positions are packed to the front of
        the row.
    Finally, every originally-negative position is set to 0.

    Returns:
        A float numpy array with the same shape as ``target``.
    """
    # NOTE(review): .squeeze() also drops the batch dim when it has size 1,
    # which would break the dim-1 sort below; presumably batch_size > 1 -- confirm.
    ind = torch.sort(e, 1)[1].squeeze()
    # target = new_target(ind) -> new_target = target(ind_inv)
    ind_inv = torch.sort(ind, 1)[1]
    mask = (target >= 0).astype(float)
    target = target * mask
    for example in range(self.batch_size):
        tar = target[example].astype(int)
        ind_inv_n = ind_inv[example].data.cpu().numpy()
        tar = ind_inv_n[tar]
        tar_aux = tar[np.where(mask[example] == 1)[0]]
        tar[:tar_aux.shape[0]] = tar_aux
        target[example] = tar
    target = target * mask
    return target
DCN.py 文件源码 项目:DCN 作者: alexnowakvila 项目源码 文件源码 阅读 44 收藏 0 点赞 0 评论 0
def eliminate_rows(self, prob_sc, ind, phis):
        """Eliminate rows of phis and of the prob_matrix scale.

        A row counts as eliminated when channel 0 of ``prob_sc`` exceeds 0.85.
        Eliminated rows are moved to the end of the ordering. ``prob_sc`` and
        ``ind`` are permuted accordingly, and ``phis`` is permuted along dims 1
        and 2 with the eliminated rows/columns zeroed.

        Args:
            prob_sc: indexed as ``prob_sc[:, :, 0]`` with rows along dim 1;
                presumably (batch, length, channels) -- TODO confirm.
            ind: index tensor gathered along dim 1 and composed with the new order.
            phis: gathered along dims 1 and 2; presumably (batch, length, length).

        Returns:
            ``(prob_sc, ind, phis_out, active)``, where ``active`` is 1 for kept
            rows and 0 for eliminated rows, in the new order.
        """
        length = prob_sc.size()[1]
        # 1 where the row is eliminated, 0 otherwise; `dtype` is a module-level
        # global not visible here.
        mask = (prob_sc[:, :, 0] > 0.85).type(dtype)
        # Positions 0..length-1 (torch.range includes its end point; it is
        # deprecated in favour of torch.arange) broadcast to mask's shape.
        rang = (Variable(torch.range(0, length - 1).unsqueeze(0)
                .expand_as(mask)).
                type(dtype))
        # Sort key: kept rows keep their position, eliminated rows get `length`.
        # Sorting therefore moves eliminated rows to the end and keeps kept rows
        # in order (order among eliminated rows is whatever torch.sort gives ties).
        ind_sc = torch.sort(rang * (1-mask) + length * mask, 1)[1]
        # permute prob_sc
        # Eliminated rows are overwritten with mm: 1 in channel 0, 0 elsewhere.
        m = mask.unsqueeze(2).expand_as(prob_sc)
        mm = m.clone()
        mm[:, :, 1:] = 0
        prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
                   ind_sc.unsqueeze(2).expand_as(prob_sc)))
        # compose permutations
        ind = torch.gather(ind, 1, ind_sc)
        # 1 for kept rows, 0 for eliminated rows, in the new order.
        active = torch.gather(1-mask, 1, ind_sc)
        # permute phis
        # phis is gathered with the composed `ind`, so it is presumably in the
        # original (un-permuted) order -- TODO confirm. Rows then columns are
        # permuted and eliminated entries are zeroed via `active`.
        active1 = active.unsqueeze(2).expand_as(phis)
        ind1 = ind.unsqueeze(2).expand_as(phis)
        active2 = active.unsqueeze(1).expand_as(phis)
        ind2 = ind.unsqueeze(1).expand_as(phis)
        phis_out = torch.gather(phis, 1, ind1) * active1
        phis_out = torch.gather(phis_out, 2, ind2) * active2
        return prob_sc, ind, phis_out, active
Logger.py 文件源码 项目:DCN 作者: alexnowakvila 项目源码 文件源码 阅读 44 收藏 0 点赞 0 评论 0
def plot_norm_points(self, Inputs_N, e, Perms, scales, fig=1):
    """Scatter-plot the first example's points per scale and save the figure.

    There is one subplot per entry of ``Perms``; subplot ``i`` uses
    ``2 ** (scales - i)`` colours. The sorted values of ``e`` (first example)
    act as per-point node ids and are integer-halved after each subplot. The
    figure is written to ``<self.path>/visualize_example.png``.
    """
    points_all = Inputs_N[0][0].data.cpu().numpy()
    e = torch.sort(e, 1)[0][0].data.cpu().numpy()
    Perms = [perm[0].data.cpu().numpy() for perm in Perms]
    plt.figure(fig)
    plt.clf()
    ee = e.copy()
    for i, perm in enumerate(Perms):
        plt.subplot(1, len(Perms), i + 1)
        num_nodes = 2 ** (scales - i)
        colors = cm.rainbow(np.linspace(0, 1, num_nodes))
        # Entries <= 0 are dropped; the rest are shifted to 0-based indices.
        perm = perm[np.where(perm > 0)[0]] - 1
        points = points_all[perm]
        e_scale = ee[perm]
        for node in range(num_nodes):
            ind = np.where(e_scale == node)[0]
            pts = points[ind]
            plt.scatter(pts[:, 0], pts[:, 1], c=colors[node])
        ee //= 2
    path = os.path.join(self.path, 'visualize_example.png')
    plt.savefig(path)
predict.py 文件源码 项目:pytorch-bilstmcrf 作者: kaniblu 项目源码 文件源码 阅读 53 收藏 0 点赞 0 评论 0
def prepare_batch(xs, lens, gpu=True):
    """Reorder a batch by decreasing length and wrap it for inference.

    ``xs`` is reordered along dim 1 (the batch dimension here). Returns
    ``(xs, lens, ridx)``: the reordered inputs, the sorted lengths, and the
    permutation that restores the original order. Each is wrapped as a
    ``volatile`` Variable and moved to the GPU when ``gpu`` is true.
    """
    sorted_lens, order = torch.sort(lens, 0, True)
    # argsort of the sort permutation is its inverse.
    restore = torch.sort(order, 0)[1]
    gather_idx = order.unsqueeze(0).unsqueeze(-1).expand_as(xs)
    reordered = torch.gather(xs, 1, gather_idx)

    wrapped = tuple(Variable(t, volatile=True)
                    for t in (reordered, sorted_lens, restore))
    if gpu:
        wrapped = tuple(t.cuda() for t in wrapped)
    return wrapped
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 114 收藏 0 点赞 0 评论 0
def test_median(self):
    """Check torch.median against the middle column of torch.sort.

    Appears to target the legacy API, where ``torch.median(x)`` reduces the last
    dim keeping it (hence ``.select(1, 0)``) and returns ``(values, indices)``,
    and where ``torch.median(values_out, indices_out, x)`` writes into result
    tensors.
    """
    for size in (155, 156):
        x = torch.rand(size, size)
        x0 = x.clone()

        res1val, res1ind = torch.median(x)
        res2val, res2ind = torch.sort(x)
        # 0-based position of the (lower) median within a sorted row.
        ind = int(math.floor((size + 1) / 2) - 1)

        self.assertEqual(res2val.select(1, ind), res1val.select(1, 0), 0)
        # The indices must agree as well, not only the values.
        self.assertEqual(res2ind.select(1, ind), res1ind.select(1, 0), 0)

        # Test use of result tensor
        res2val = torch.Tensor()
        res2ind = torch.LongTensor()
        torch.median(res2val, res2ind, x)
        self.assertEqual(res2val, res1val, 0)
        self.assertEqual(res2ind, res1ind, 0)

        # Test non-default dim
        res1val, res1ind = torch.median(x, 0)
        res2val, res2ind = torch.sort(x, 0)
        self.assertEqual(res1val[0], res2val[ind], 0)
        self.assertEqual(res1ind[0], res2ind[ind], 0)

        # input unchanged
        self.assertEqual(x, x0, 0)
aucmeter.py 文件源码 项目:tnt 作者: pytorch 项目源码 文件源码 阅读 49 收藏 0 点赞 0 评论 0
def value(self):
    """Return ``(area, tpr, fpr)`` for the ROC curve of the recorded scores.

    Scores are ranked in descending order. Targets equal to 1 count as
    positives and anything else as negatives. ``tpr`` / ``fpr`` have one
    more entry than there are scores (a leading 0), and the area is computed
    with the trapezoidal rule. Returns ``0.5`` when nothing has been added.
    """
    if self.scores.shape[0] == 0:
        return 0.5

    ranked, order = torch.sort(torch.from_numpy(self.scores), dim=0, descending=True)
    ranked = ranked.numpy()
    order = order.numpy()
    count = ranked.size

    # Cumulative positive / negative counts down the ranking (exact integers).
    is_pos = self.targets[order] == 1
    tpr = np.zeros(shape=(count + 1), dtype=np.float64)
    fpr = np.zeros(shape=(count + 1), dtype=np.float64)
    tpr[1:] = np.cumsum(is_pos)
    fpr[1:] = np.cumsum(~is_pos)

    tpr /= (self.targets.sum() * 1.0)
    fpr /= ((self.targets - 1.0).sum() * -1.0)

    # Trapezoid weights: each point collects the widths of its neighbouring steps.
    widths = np.diff(fpr)
    weights = np.zeros(fpr.shape)
    weights[:-1] = widths
    weights[1:] += widths
    area = (weights * tpr).sum() / 2.0

    return (area, tpr, fpr)
utils.py 文件源码 项目:MIL.pytorch 作者: gujiuxiang 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def compute_precision_score_mapping_torch(thresh, prec, score):
    """Unfinished torch-based variant of the precision/score mapping.

    ``thresh`` and then ``prec`` (numpy arrays, converted with
    ``torch.from_numpy``) are sorted ascending along dim 0. The results are
    discarded, ``score`` is unused, and the function always returns ``None``.
    """
    # NOTE(review): stub -- the sorted values and indices are never used.
    for array in (thresh, prec):
        torch.sort(torch.from_numpy(array), 0, descending=False)
    return None
rnn.py 文件源码 项目:jack 作者: uclmr 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def forward(self, inputs, lengths=None, start_state=None):
    """Run ``self._bilstm`` over batch-first ``inputs``.

    Args:
        inputs: batch-first input tensor.
        lengths: optional 1-D tensor of sequence lengths. When given, the batch is
            sorted by decreasing length, packed, run, unpacked and restored to
            the caller's order.
        start_state: ``(h_0, c_0)``, used when ``self._start_state_given``.
            Otherwise ``self._lstm_start_hidden`` / ``self._lstm_start_state``
            are expanded over the batch.

    Returns:
        ``(output, (h_n, c_n))`` in the caller's batch order.
    """
    if not self._start_state_given:
        batch_size = inputs.size(0)
        start_hidden = self._lstm_start_hidden.unsqueeze(1).expand(2, batch_size, self._size).contiguous()
        start_state = self._lstm_start_state.unsqueeze(1).expand(2, batch_size, self._size).contiguous()
        start_state = (start_hidden, start_state)

    if lengths is not None:
        new_lengths, indices = torch.sort(lengths, dim=0, descending=True)
        inputs = torch.index_select(inputs, 0, indices)
        if self._start_state_given:
            start_state = (torch.index_select(start_state[0], 1, indices),
                           torch.index_select(start_state[1], 1, indices))
        # pack_padded_sequence needs plain ints. `.data.tolist()` works on both
        # legacy Variables and modern tensors. `l.data[0]` fails on the 0-d
        # tensors that iterating yields in newer PyTorch.
        new_lengths = new_lengths.data.tolist()
        inputs = nn.utils.rnn.pack_padded_sequence(inputs, new_lengths, batch_first=True)

    output, (h_n, c_n) = self._bilstm(inputs, start_state)

    if lengths is not None:
        output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
        # argsort of the sort permutation is its inverse: restores caller order.
        _, back_indices = torch.sort(indices, dim=0)
        output = torch.index_select(output, 0, back_indices)
        h_n = torch.index_select(h_n, 1, back_indices)
        c_n = torch.index_select(c_n, 1, back_indices)

    return output, (h_n, c_n)
segment.py 文件源码 项目:jack 作者: uclmr 项目源码 文件源码 阅读 52 收藏 0 点赞 0 评论 0
def segment_max(inputs, segment_ids, num_segments=None, default=0.0):
    """Segment-wise max of the rows of ``inputs``, computed by ``SegmentMax``.

    Rows are first reordered so that equal ``segment_ids`` are contiguous. The
    reordered rows and the sorted ids are then passed to
    ``SegmentMax.apply(..., num_segments, default)``.
    """
    # Tuned to minimise the number of PyTorch calls; it expects most segments
    # to hold 0 or 1 elements.
    ordered_ids, permutation = torch.sort(segment_ids)
    return SegmentMax.apply(torch.index_select(inputs, 0, permutation),
                            ordered_ids, num_segments, default)
vectorize.py 文件源码 项目:pytorch-skipthoughts 作者: kaniblu 项目源码 文件源码 阅读 46 收藏 0 点赞 0 评论 0
def prepare_batch(self, x, x_lens):
    """Sort a batch by decreasing length and wrap it for inference.

    Returns ``(x, x_lens, x_ridx)``: rows of ``x`` reordered by decreasing
    length, the sorted lengths, and the (long) permutation that restores the
    original order. Each is wrapped as a ``volatile`` Variable and moved to
    the GPU when ``self.is_cuda``.
    """
    sorted_lens, order = torch.sort(x_lens, 0, True)
    # argsort of the sort permutation is its inverse.
    restore = torch.sort(order)[1]
    batch = x[order]

    wrapped = [
        Variable(batch, volatile=True),
        Variable(sorted_lens, volatile=True),
        Variable(restore.long(), volatile=True),
    ]
    if self.is_cuda:
        wrapped = [v.cuda() for v in wrapped]
    return tuple(wrapped)
train.py 文件源码 项目:pytorch-skipthoughts 作者: kaniblu 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def prepare_batch(self, batch_data, volatile=False):
    """Sort a (sentence, context-sentences) batch by length and wrap it for the model.

    ``batch_data`` is ``(x, x_lens, ys, ys_lens)``.
      - ``x`` is sorted by decreasing ``x_lens``.
      - ``ys`` is sorted by decreasing ``ys_lens`` along the batch dimension,
        independently for each context position.
      - ``ys`` is then split into decoder inputs (all steps but the last) and
        targets (all steps but the first).

    Returns:
        ``(x, x_lens, ys_i, ys_t, ys_lens - 1, xys_idx)``, where ``xys_idx[i][c]``
        is the row of sorted ``x`` from which sorted ``ys`` entry ``(i, c)``
        originates.
    """
    x, x_lens, ys, ys_lens = batch_data
    batch_dim = 0 if self.batch_first else 1
    context_dim = 1 if self.batch_first else 0

    x_lens, x_idx = torch.sort(x_lens, 0, True)
    _, x_ridx = torch.sort(x_idx)
    ys_lens, ys_idx = torch.sort(ys_lens, batch_dim, True)

    x_ridx_exp = x_ridx.unsqueeze(context_dim).expand_as(ys_idx)
    xys_idx = torch.gather(x_ridx_exp, batch_dim, ys_idx)

    x = x[x_idx]
    ys = torch.gather(ys, batch_dim, ys_idx.unsqueeze(-1).expand_as(ys))

    x = Variable(x, volatile=volatile)
    x_lens = Variable(x_lens, volatile=volatile)
    ys_i = Variable(ys[..., :-1], volatile=volatile).contiguous()
    ys_t = Variable(ys[..., 1:], volatile=volatile).contiguous()
    ys_lens = Variable(ys_lens - 1, volatile=volatile)
    xys_idx = Variable(xys_idx, volatile=volatile)

    if self.is_cuda:
        # `non_blocking` replaces the old `async` keyword argument; `async` is a
        # reserved word since Python 3.7, so `cuda(async=True)` is a SyntaxError.
        x = x.cuda(non_blocking=True)
        x_lens = x_lens.cuda(non_blocking=True)
        ys_i = ys_i.cuda(non_blocking=True)
        ys_t = ys_t.cuda(non_blocking=True)
        ys_lens = ys_lens.cuda(non_blocking=True)
        xys_idx = xys_idx.cuda(non_blocking=True)

    return x, x_lens, ys_i, ys_t, ys_lens, xys_idx
Beam.py 文件源码 项目:attention-is-all-you-need-pytorch 作者: jadore801120 项目源码 文件源码 阅读 47 收藏 0 点赞 0 评论 0
def advance(self, word_lk):
    """Advance the beam by one step and report whether it has finished.

    Args:
        word_lk: word scores indexed as ``word_lk[beam, word]``; ``size(1)`` is
            the vocabulary size. They are added to the running beam scores,
            except on the first step, where only ``word_lk[0]`` is used.

    Returns:
        ``self.done``, which is set once the top hypothesis ends in ``Constants.EOS``.
    """
    num_words = word_lk.size(1)

    # Sum the previous scores.
    if len(self.prev_ks) > 0:
        beam_lk = word_lk + self.scores.unsqueeze(1).expand_as(word_lk)
    else:
        # First step: only beam 0 is expanded (presumably all beams start identical).
        beam_lk = word_lk[0]

    flat_beam_lk = beam_lk.view(-1)

    # One top-k over the flattened scores is enough (a repeated identical call
    # would recompute the same result).
    best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)

    self.all_scores.append(self.scores)
    self.scores = best_scores

    # best_scores_id indexes the flattened beam x word array, so recover the
    # source beam and the word. Floor division keeps prev_k integral; `/` on
    # integer tensors gives float results in newer PyTorch.
    prev_k = best_scores_id // num_words
    self.prev_ks.append(prev_k)
    self.next_ys.append(best_scores_id - prev_k * num_words)

    # End condition is when top-of-beam is EOS.
    if self.next_ys[-1][0] == Constants.EOS:
        self.done = True
        self.all_scores.append(self.scores)

    return self.done


问题


面经


文章

微信
公众号

扫码关注公众号