Python index_select() usage examples, collected from open-source project source code
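Before the collected snippets, here is a minimal standalone sketch of the call itself, written against the current torch API (the tensors are made up for illustration):

import torch

mat = torch.arange(12).view(4, 3)                        # a 4x3 matrix, rows numbered 0..3
idx = torch.tensor([3, 0, 0])                            # indices may repeat and reorder

rows = torch.index_select(mat, 0, idx)                   # rows 3, 0, 0 -> shape (3, 3)
cols = torch.index_select(mat, 1, torch.tensor([2, 0]))  # columns 2 and 0 -> shape (4, 2)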

sparse.py (project: pytorch-dist, author: apaszke)
def forward(self, indices, weight):
        assert indices.dim() <= 2
        assert not self.needs_input_grad[0], "Embedding doesn't " \
            "compute the gradient w.r.t. the indices"
        self._backend = type2backend[type(weight)]
        self._weight_size = weight.size()

        if not indices.is_contiguous():
            self._indices = indices.contiguous()
            indices = self._indices
        else:
            self.save_for_backward(indices)

        if self.max_norm is not None:
            self._renorm(indices, weight)

        output = weight.new()
        if indices.dim() == 1:
            torch.index_select(output, weight, 0, indices)
        else:
            torch.index_select(output, weight, 0, indices.view(-1))
            output = output.view(indices.size(0), indices.size(1), weight.size(1))

        return output
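Note that this snippet, like the LookupTable.py and Index.py snippets further down, uses the early out-tensor form torch.index_select(output, weight, dim, indices) from old PyTorch releases. In current PyTorch the same two calls are written through the return value (a sketch, not part of the original file):

output = torch.index_select(weight, 0, indices)                      # 1-D indices
output = torch.index_select(weight, 0, indices.view(-1)).view(
    indices.size(0), indices.size(1), weight.size(1))                # 2-D indices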
air.py (project: pyro, author: uber)
def expand_z_where(z_where):
    # Takes a batch of three-vectors and massages them into a batch of
    # 2x3 matrices with elements like so:
    # [s,x,y] -> [[s,0,x],
    #             [0,s,y]]
    n = z_where.size(0)
    out = torch.cat((ng_zeros([1, 1]).type_as(z_where).expand(n, 1), z_where), 1)
    ix = Variable(expansion_indices)
    if z_where.is_cuda:
        ix = ix.cuda()
    out = torch.index_select(out, 1, ix)
    out = out.view(n, 2, 3)
    return out


# Scaling by `1/scale` here is unsatisfactory, as `scale` could be
# zero.
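To make the column shuffle in expand_z_where above concrete: after the cat, each row of out is [0, s, x, y]. Assuming expansion_indices is [1, 0, 2, 0, 1, 3] (an assumption; it is defined elsewhere in air.py and not shown here), index_select along dim 1 produces [s, 0, x, 0, s, y], which reshapes into the target 2x3 matrix:

import torch

z_where = torch.tensor([[2.0, 5.0, 7.0]])               # one [s, x, y] row
out = torch.cat((torch.zeros(1, 1), z_where), 1)        # [[0., 2., 5., 7.]]
expansion_indices = torch.tensor([1, 0, 2, 0, 1, 3])    # assumed value, see note above
out = torch.index_select(out, 1, expansion_indices).view(-1, 2, 3)
# out[0] == [[2., 0., 5.],
#            [0., 2., 7.]]   i.e. [[s, 0, x], [0, s, y]]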
test_sampling.py (project: pyro, author: uber)
def setUp(self):

        # simple Gaussian-emission HMM
        def model():
            p_latent = pyro.param("p1", Variable(torch.Tensor([[0.7], [0.3]])))
            p_obs = pyro.param("p2", Variable(torch.Tensor([[0.9], [0.1]])))

            latents = [Variable(torch.ones(1, 1))]
            observes = []
            for t in range(self.model_steps):

                latents.append(
                    pyro.sample("latent_{}".format(str(t)),
                                Bernoulli(torch.index_select(p_latent, 0, latents[-1].view(-1).long()))))

                observes.append(
                    pyro.observe("observe_{}".format(str(t)),
                                 Bernoulli(torch.index_select(p_obs, 0, latents[-1].view(-1).long())),
                                 self.data[t]))
            return torch.sum(torch.cat(latents))

        self.model_steps = 3
        self.data = [pyro.ones(1, 1) for _ in range(self.model_steps)]
        self.model = model
segment.py (project: jack, author: uclmr)
def backward(ctx, grad_outputs):
        size = grad_outputs.size(1)
        segm_sorted = torch.sort(ctx.rev_segm_sorted)[1]
        grad_outputs = torch.index_select(grad_outputs, 0, segm_sorted)

        offset = [ctx.num_zeros]

        def backward_segment(l, n):
            segment_grad = grad_outputs.narrow(0, offset[0], n // l)
            if l > 1:
                segment_grad = _MyMax.backward(ctx.maxes[l], segment_grad)[0].view(n, size)
            offset[0] += n // l
            return segment_grad

        segment_grads = [backward_segment(l, n) for l, n in enumerate(ctx.num_lengths) if n > 0]
        grads = torch.cat(segment_grads, 0)
        rev_length_sorted = torch.sort(ctx.lengths_sorted)[1]
        grads = torch.index_select(grads, 0, rev_length_sorted)

        return grads, None, None, None
vectorize.py (project: pytorch-skipthoughts, author: kaniblu)
def vectorize(self, x, x_lens):
        x, x_lens, r_idx = self.prepare_batch(x, x_lens)

        if len(x.size()) == 2:
            vectors = self.model.encode(x, x_lens)
        elif len(x.size()) == 3:
            vectors = self.model.encode_embed(x, x_lens)
        else:
            raise Exception()

        vectors = torch.index_select(vectors, 0, r_idx)

        if self.is_cuda:
            vectors = vectors.cpu()

        return vectors.data
faster_rcnn.py (project: faster_rcnn_pytorch, author: longcw)
def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data):
        # classification loss
        rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_label = rpn_data[0].view(-1)

        rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
        rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
        rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

        fg_cnt = torch.sum(rpn_label.data.ne(0))

        rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)

        # box loss
        rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
        rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
        rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

        rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)

        return rpn_cross_entropy, rpn_loss_box
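The rpn_keep / index_select pair above is a general recipe for dropping rows whose label equals the ignore value (-1). A minimal sketch with made-up data, no CUDA or Variable wrapping:

import torch

scores = torch.randn(5, 2)                      # per-anchor class scores
labels = torch.tensor([1, -1, 0, -1, 1])        # -1 marks anchors to ignore

keep = labels.ne(-1).nonzero().squeeze(1)       # kept row indices: [0, 2, 4]
scores_kept = torch.index_select(scores, 0, keep)
labels_kept = torch.index_select(labels, 0, keep)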
mps_base.py (project: MSDN, author: yikang-li)
def prepare_message(self, target_features, source_features, select_mat, gate_module):
        feature_data = []

        transfer_list = np.where(select_mat > 0)
        source_indices = Variable(torch.from_numpy(transfer_list[1]).type(torch.LongTensor)).cuda()
        target_indices = Variable(torch.from_numpy(transfer_list[0]).type(torch.LongTensor)).cuda()
        source_f = torch.index_select(source_features, 0, source_indices)
        target_f = torch.index_select(target_features, 0, target_indices)
        transferred_features = gate_module(target_f, source_f)

        for f_id in range(target_features.size()[0]):
            if len(np.where(select_mat[f_id, :] > 0)[0]) > 0:
                feature_indices = np.where(transfer_list[0] == f_id)[0]
                indices = Variable(torch.from_numpy(feature_indices).type(torch.LongTensor)).cuda()
                features = torch.index_select(transferred_features, 0, indices).mean(0).view(-1)
                feature_data.append(features)
            else:
                temp = Variable(torch.zeros(target_features.size()[1:]), requires_grad=True).type(torch.FloatTensor).cuda()
                feature_data.append(temp)
        return torch.stack(feature_data, 0)
sparse.py (project: pytorch, author: tylergenter)
def forward(self, indices, weight):
        assert indices.dim() <= 2
        assert not self.needs_input_grad[0], "Embedding doesn't " \
            "compute the gradient w.r.t. the indices"

        self._backend = type2backend[type(weight)]
        self._weight_size = weight.size()

        if not indices.is_contiguous():
            self._indices = indices.contiguous()
            indices = self._indices
        else:
            self.save_for_backward(indices)

        output = weight.new()
        if self.max_norm is not None:
            self._renorm(indices, weight)

        if indices.dim() == 1:
            output = torch.index_select(weight, 0, indices)
        else:
            output = torch.index_select(weight, 0, indices.view(-1))
            output = output.view(indices.size(0), indices.size(1), weight.size(1))

        return output
sparse.py (project: pytorch-coriander, author: hughperkins)
def forward(self, indices, weight):
        assert indices.dim() <= 2
        assert not self.needs_input_grad[0], "Embedding doesn't " \
            "compute the gradient w.r.t. the indices"

        self._backend = type2backend[type(weight)]
        self._weight_size = weight.size()

        if not indices.is_contiguous():
            self._indices = indices.contiguous()
            indices = self._indices
        else:
            self.save_for_backward(indices)

        output = weight.new()
        if self.max_norm is not None:
            self._renorm(indices, weight)

        if indices.dim() == 1:
            output = torch.index_select(weight, 0, indices)
        else:
            output = torch.index_select(weight, 0, indices.view(-1))
            output = output.view(indices.size(0), indices.size(1), weight.size(1))

        return output
__init__.py (project: gpytorch, author: jrg365)
def left_interp(interp_indices, interp_values, rhs):
    is_vector = rhs.ndimension() == 1

    if is_vector:
        res = rhs.index_select(0, interp_indices.view(-1)).view(*interp_values.size())
        res = res.mul(interp_values)
        return res.sum(-1)

    else:
        interp_size = list(interp_indices.size()) + [rhs.size(-1)]
        rhs_size = deepcopy(interp_size)
        rhs_size[-3] = rhs.size()[-2]
        interp_indices_expanded = interp_indices.unsqueeze(-1).expand(*interp_size)
        res = rhs.unsqueeze(-2).expand(*rhs_size).gather(-3, interp_indices_expanded)
        res = res.mul(interp_values.unsqueeze(-1).expand(interp_size))
        return res.sum(-2)
faster_rcnn.py (project: pytorch_RFCN, author: PureDiors)
def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data):
        # classification loss
        rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_label = rpn_data[0].view(-1)

        rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
        rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
        rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

        fg_cnt = torch.sum(rpn_label.data.ne(0))

        rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)

        # box loss
        rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
        rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
        rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

        rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)

        return rpn_cross_entropy, rpn_loss_box
rfcn.py (project: pytorch_RFCN, author: PureDiors)
def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data):
        # classification loss
        rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_label = rpn_data[0].view(-1)

        rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
        rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
        rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

        fg_cnt = torch.sum(rpn_label.data.ne(0))

        rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)

        # box loss
        rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
        rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
        rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

        rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)

        return rpn_cross_entropy, rpn_loss_box
train_batch.py (project: kdnet.pytorch, author: fxia22)
def split_ps(point_set):
    #print point_set.size()
    num_points = point_set.size()[0]/2
    diff = point_set.max(dim=0)[0] - point_set.min(dim=0)[0]
    dim = torch.max(diff, dim = 1)[1][0,0]
    cut = torch.median(point_set[:,dim])[0][0]
    left_idx = torch.squeeze(torch.nonzero(point_set[:,dim] > cut))
    right_idx = torch.squeeze(torch.nonzero(point_set[:,dim] < cut))
    middle_idx = torch.squeeze(torch.nonzero(point_set[:,dim] == cut))

    if torch.numel(left_idx) < num_points:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(num_points - torch.numel(left_idx))], 0)
    if torch.numel(right_idx) < num_points:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(num_points - torch.numel(right_idx))], 0)

    left_ps = torch.index_select(point_set, dim = 0, index = left_idx)
    right_ps = torch.index_select(point_set, dim = 0, index = right_idx)
    return left_ps, right_ps, dim
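split_ps appears, with small variations, in several kdnet.pytorch scripts below; it splits a point set along the axis of greatest spread and gathers each half with index_select. A simplified sketch of the same gather step (it skips the padding of the points equal to the cut that the function above handles):

import torch

points = torch.rand(8, 3)                               # toy point cloud
axis = 0                                                # pretend x has the largest spread
cut = points[:, axis].median()

left_idx = (points[:, axis] > cut).nonzero().squeeze(1)
right_idx = (points[:, axis] <= cut).nonzero().squeeze(1)

left_ps = torch.index_select(points, 0, left_idx)
right_ps = torch.index_select(points, 0, right_idx)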
train.py (project: kdnet.pytorch, author: fxia22)
def split_ps(point_set):
    #print point_set.size()
    num_points = point_set.size()[0]/2
    diff = point_set.max(dim=0, keepdim = True)[0] - point_set.min(dim=0, keepdim = True)[0]
    dim = torch.max(diff, dim = 1, keepdim = True)[1][0,0]
    cut = torch.median(point_set[:,dim], keepdim = True)[0][0]
    left_idx = torch.squeeze(torch.nonzero(point_set[:,dim] > cut))
    right_idx = torch.squeeze(torch.nonzero(point_set[:,dim] < cut))
    middle_idx = torch.squeeze(torch.nonzero(point_set[:,dim] == cut))

    if torch.numel(left_idx) < num_points:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(num_points - torch.numel(left_idx))], 0)
    if torch.numel(right_idx) < num_points:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(num_points - torch.numel(right_idx))], 0)

    left_ps = torch.index_select(point_set, dim = 0, index = left_idx)
    right_ps = torch.index_select(point_set, dim = 0, index = right_idx)
    return left_ps, right_ps, dim
train_MG2.py (project: kdnet.pytorch, author: fxia22)
def split_ps(point_set):
    #print point_set.size()
    num_points = point_set.size()[0]/2
    diff = point_set.max(dim=0)[0] - point_set.min(dim=0)[0]
    diff = diff[:3]
    dim = torch.max(diff, dim = 1)[1][0,0]
    cut = torch.median(point_set[:,dim])[0][0]
    left_idx = torch.squeeze(torch.nonzero(point_set[:,dim] > cut))
    right_idx = torch.squeeze(torch.nonzero(point_set[:,dim] < cut))
    middle_idx = torch.squeeze(torch.nonzero(point_set[:,dim] == cut))

    if torch.numel(left_idx) < num_points:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(num_points - torch.numel(left_idx))], 0)
    if torch.numel(right_idx) < num_points:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(num_points - torch.numel(right_idx))], 0)

    left_ps = torch.index_select(point_set, dim = 0, index = left_idx)
    right_ps = torch.index_select(point_set, dim = 0, index = right_idx)
    return left_ps, right_ps, dim
test.py (project: kdnet.pytorch, author: fxia22)
def split_ps(point_set):
    #print point_set.size()
    num_points = point_set.size()[0]/2
    diff = point_set.max(dim=0)[0] - point_set.min(dim=0)[0] 
    dim = torch.max(diff, dim = 1)[1][0,0]
    cut = torch.median(point_set[:,dim])[0][0]  
    left_idx = torch.squeeze(torch.nonzero(point_set[:,dim] > cut))
    right_idx = torch.squeeze(torch.nonzero(point_set[:,dim] < cut))
    middle_idx = torch.squeeze(torch.nonzero(point_set[:,dim] == cut))

    if torch.numel(left_idx) < num_points:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(num_points - torch.numel(left_idx))], 0)
    if torch.numel(right_idx) < num_points:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(num_points - torch.numel(right_idx))], 0)

    left_ps = torch.index_select(point_set, dim = 0, index = left_idx)
    right_ps = torch.index_select(point_set, dim = 0, index = right_idx)
    return left_ps, right_ps, dim
faster_rcnn.py (project: intel-cervical-cancer, author: wangg12)
def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data):
        # classification loss
        rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_label = rpn_data[0].view(-1)

        rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
        rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
        rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

        fg_cnt = torch.sum(rpn_label.data.ne(0))

        rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)

        # box loss
        rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
        rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
        rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

        rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)

        return rpn_cross_entropy, rpn_loss_box
matrix.py (project: paysage, author: drckf)
def index_select(mat: T.Tensor, index: T.LongTensor, dim: int = 0) -> T.Tensor:
    """
    Select the specified indices of a tensor along dimension dim.
    For example, dim = 1 is equivalent to mat[:, index] in numpy.

    Args:
        mat (tensor (num_samples, num_units))
        index (tensor; 1-dimensional)
        dim (int)

    Returns:
        if dim == 0:
            mat[index, :]
        if dim == 1:
            mat[:, index]

    """
    return torch.index_select(mat, dim, index)
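A short usage sketch, assuming the wrapper defined above is in scope (values made up):

import torch

mat = torch.arange(6.0).view(2, 3)                     # [[0., 1., 2.], [3., 4., 5.]]
rows = index_select(mat, torch.tensor([1, 0]))         # mat[[1, 0], :]
cols = index_select(mat, torch.tensor([2, 0]), dim=1)  # mat[:, [2, 0]]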
train.py (project: pytorch-bilstmcrf, author: kaniblu)
def prepare_val_texts(model, batch_xs, batch_y, batch_lens,
                     logits, preds, n_previews):
    idx = np.random.permutation(np.arange(batch_lens.size(0)))[:n_previews]
    idx_v = Variable(torch.LongTensor(idx), volatile=True)

    if model.is_cuda:
        idx_v = idx_v.cuda()

    logits = torch.index_select(logits, 0, idx_v)
    bilstm_preds = logits.cpu().max(2)[1].squeeze(-1).data.numpy()
    crf_preds = preds.cpu().data.numpy()[idx]
    xs = batch_xs.cpu().data.numpy()[:, idx]
    y = batch_y.cpu().data.numpy()[idx]
    lens = batch_lens.cpu().data.numpy()[idx]

    sents = val_sents(model.word_vocabs, model.label_vocab,
                      xs, y, bilstm_preds, crf_preds, lens)
    texts = val_texts(*sents)

    return texts
main.py (project: SGAN, author: YuhangSong)
def get_batch(self):
        indexs = self.indexs_selector.random_(0,self.dataset.size()[0]).cuda()
        return torch.index_select(self.dataset,0,indexs)
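The same random-minibatch idea as a standalone sketch (dataset shape and batch size are assumed; no CUDA):

import torch

dataset = torch.randn(1000, 32)                            # assumed: 1000 samples of dim 32
batch_size = 64
indexs = torch.LongTensor(batch_size).random_(0, dataset.size(0))
batch = torch.index_select(dataset, 0, indexs)             # (64, 32), sampled with replacement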
main.py (project: SGAN, author: YuhangSong)
def get_batch(self):
        indexs = self.indexs_selector.random_(0,self.dataset.size()[0]).cuda()
        return torch.index_select(self.dataset,0,indexs)
LookupTable.py (project: pytorch-dist, author: apaszke)
def updateOutput(self, input):
        self.renorm(input)
        input = self._makeInputContiguous(input)
        if input.dim() == 1:
           torch.index_select(self.output, self.weight, 0, input)
        elif input.dim() == 2:
           torch.index_select(self.output, self.weight, 0, input.view(-1))
           self.output = self.output.view(input.size(0), input.size(1), self.weight.size(1))
        else:
           raise RuntimeError("input must be a vector or matrix")

        return self.output
Index.py (project: pytorch-dist, author: apaszke)
def updateOutput(self, input):
         t = input[0]
         index = input[1]
         torch.index_select(self.output, t, self.dimension, index)
         return self.output
PartialLinear.py (project: pytorch-dist, author: apaszke)
def updateOutput(self, input):
        self.output.set_(self.network.forward([input, self.partition]))
        if self.bias:
            self.output.add_(torch.index_select(self.bias, 1, self.partition).expand_as(self.output))
            self.addBuffer = self.addBuffer or input.new()
            if self.addBuffer.nelement() != input.size(0):
                self.addBuffer.resize_(input.size(0)).fill_(1)

        return self.output
polyphonic_data_loader.py (project: pyro, author: uber)
def reverse_sequences_torch(mini_batch, seq_lengths):
    reversed_mini_batch = ng_zeros(mini_batch.size(), type_as=mini_batch.data)
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        time_slice = np.arange(T - 1, -1, -1)
        time_slice = Variable(torch.cuda.LongTensor(time_slice)) if 'cuda' in mini_batch.data.type() \
            else Variable(torch.LongTensor(time_slice))
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch
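The per-sequence reversal above reduces to indexing with a descending range. A minimal sketch with current torch calls (shapes are made up):

import torch

seq = torch.arange(12.0).view(4, 3)            # T_max=4 timesteps, feature dim 3
T = 3                                          # only the first 3 steps are valid
rev_idx = torch.arange(T - 1, -1, -1)          # tensor([2, 1, 0])
reversed_prefix = torch.index_select(seq, 0, rev_idx)
# rows 2, 1, 0 of seq; the padding rows T..T_max-1 are left untouched in the batched version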


# this function takes the hidden state as output by the PyTorch rnn and
# unpacks it; it also reverses each sequence temporally
util.py (project: pyro, author: uber)
def log_sum_exp(vecs):
    n = len(vecs.size())
    if n == 1:
        vecs = vecs.view(1, -1)
    _, idx = torch.max(vecs, 1)
    max_score = torch.index_select(vecs, 1, idx.view(-1))
    ret = max_score + torch.log(torch.sum(torch.exp(vecs - max_score.expand_as(vecs))))
    if n == 1:
        return ret.view(-1)
    return ret
rnn.py (project: jack, author: uclmr)
def forward(self, inputs, lengths=None, start_state=None):
        if not self._start_state_given:
            batch_size = inputs.size(0)
            start_hidden = self._lstm_start_hidden.unsqueeze(1).expand(2, batch_size, self._size).contiguous()
            start_state = self._lstm_start_state.unsqueeze(1).expand(2, batch_size, self._size).contiguous()
            start_state = (start_hidden, start_state)

        if lengths is not None:
            new_lengths, indices = torch.sort(lengths, dim=0, descending=True)
            inputs = torch.index_select(inputs, 0, indices)
            if self._start_state_given:
                start_state = (torch.index_select(start_state[0], 1, indices),
                               torch.index_select(start_state[1], 1, indices))
            new_lengths = [l.data[0] for l in new_lengths]
            inputs = nn.utils.rnn.pack_padded_sequence(inputs, new_lengths, batch_first=True)

        output, (h_n, c_n) = self._bilstm(inputs, start_state)

        if lengths is not None:
            output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
            _, back_indices = torch.sort(indices, dim=0)
            output = torch.index_select(output, 0, back_indices)
            h_n = torch.index_select(h_n, 1, back_indices)
            c_n = torch.index_select(c_n, 1, back_indices)

        return output, (h_n, c_n)
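The sort-then-restore pattern above (sort by length so the batch can be packed, then undo the permutation with a second index_select) can be checked in isolation:

import torch

lengths = torch.tensor([3, 7, 5])
data = torch.randn(3, 7, 4)                            # padded batch of sequences

sorted_lengths, indices = torch.sort(lengths, dim=0, descending=True)
sorted_data = torch.index_select(data, 0, indices)     # reorder for pack_padded_sequence

_, back_indices = torch.sort(indices, dim=0)           # inverse permutation
restored = torch.index_select(sorted_data, 0, back_indices)
assert torch.equal(restored, data)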
segment.py (project: jack, author: uclmr)
def segment_max(inputs, segment_ids, num_segments=None, default=0.0):
    # highly optimized to decrease the amount of actual invocation of pytorch calls
    # assumes that most segments have 1 or 0 elements
    segment_ids, indices = torch.sort(segment_ids)
    inputs = torch.index_select(inputs, 0, indices)
    output = SegmentMax.apply(inputs, segment_ids, num_segments, default)
    return output
xqa.py (project: jack, author: uclmr)
def forward(self, start_scores, end_scores, answer_span, answer_to_question):
        """very common XQA loss function."""
        long_tensor = torch.cuda.LongTensor if torch.cuda.device_count() > 0 else torch.LongTensor
        answer_span = answer_span.type(long_tensor)
        start, end = answer_span[:, 0], answer_span[:, 1]

        batch_size1 = start.data.shape[0]
        batch_size2 = start_scores.data.shape[0]
        is_aligned = batch_size1 == batch_size2

        start_scores = start_scores if is_aligned else torch.index_select(start_scores, 0, answer_to_question)
        end_scores = end_scores if is_aligned else torch.index_select(end_scores, 0, answer_to_question)

        partitioned_loss = []
        for i, j in enumerate(answer_to_question):
            j = j.data[0]
            while j >= len(partitioned_loss):
                partitioned_loss.append([])
            loss = -torch.index_select(F.log_softmax(start_scores[i]), 0, start[i])
            loss -= torch.index_select(F.log_softmax(end_scores[i]), 0, end[i])
            partitioned_loss[j].append(loss)

        for j, l in enumerate(partitioned_loss):
            partitioned_loss[j] = torch.stack(l).min()

        loss = torch.stack(partitioned_loss).mean()
        return loss
model.py (project: pytorch-skipthoughts, author: kaniblu)
def forward(self, x, x_lens, ys, ys_lens, xys_idx):
        x = self.embeddings(x)
        h = self._encode_embed(x, x_lens)

        if self.batch_first:
            ys = ys.transpose(1, 0)
            ys_lens = ys_lens.transpose(1, 0)
            xys_idx = xys_idx.transpose(1, 0)

        logits_list = []

        for dec_idx, (y, y_lens, xy_idx) in enumerate(
                zip(ys, ys_lens, xys_idx)):
            h_dec = torch.index_select(h, 0, xy_idx)
            logits = self._decode(dec_idx, h_dec, y, y_lens)

            nil_batches = len(h_dec) - len(logits)
            if nil_batches:
                logits = pad_batch(logits, nil_batches, True)

            logits_list.append(logits.unsqueeze(0))

        logits = torch.cat(logits_list)

        if self.batch_first:
            logits = logits.transpose(1, 0)

        return logits, h

