python类gather()的实例源码

model.py 文件源码 项目:pytorch-bilstmcrf 作者: kaniblu 项目源码 文件源码 阅读 71 收藏 0 点赞 0 评论 0
def transition_score(self, labels, lens):
        """Sum the transition scores along each label sequence of the batch.

        The labels are framed with ``self.start_idx`` in front and
        ``self.stop_idx`` behind, ``self.transitions`` is looked up as
        ``transitions[next_label, previous_label]`` for every adjacent pair,
        and the masked results are summed over the time dimension.

        Arguments:
             labels: [batch_size, seq_len] LongTensor
             lens: [batch_size] LongTensor

        Returns:
             The summed transition score per batch element (presumably a
             [batch_size] tensor -- TODO confirm for the PyTorch version used).
        """
        batch_size, seq_len = labels.size()

        # pad labels with <start> and <stop> indices
        # `new(...)` allocates without initialising; column 0 and columns
        # 1..seq_len are filled right below, the last column is only ever
        # reached through the mask blend.
        labels_ext = Variable(labels.data.new(batch_size, seq_len + 2))
        labels_ext[:, 0] = self.start_idx
        labels_ext[:, 1:-1] = labels
        # assumes sequence_mask(n, max_len=m) yields 1 for positions < n and 0
        # elsewhere -- TODO confirm, the helper is defined outside this block
        mask = sequence_mask(lens + 1, max_len=seq_len + 2).long()
        pad_stop = Variable(labels.data.new(1).fill_(self.stop_idx))
        pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2)
        # keep the label where mask == 1, substitute stop_idx where mask == 0
        labels_ext = (1 - mask) * pad_stop + mask * labels_ext
        labels = labels_ext

        trn = self.transitions

        # obtain transition vector for each label in batch and timestep
        # (except the last ones)
        trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size())
        lbl_r = labels[:, 1:]
        lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0))
        # trn_row[b, t, j] == trn[lbl_r[b, t], j]
        trn_row = torch.gather(trn_exp, 1, lbl_rexp)

        # obtain transition score from the transition vector for each label
        # in batch and timestep (except the first ones)
        lbl_lexp = labels[:, :-1].unsqueeze(-1)
        # trn_scr[b, t, 0] == trn[lbl_r[b, t], labels[b, t]]
        trn_scr = torch.gather(trn_row, 2, lbl_lexp)
        trn_scr = trn_scr.squeeze(-1)

        # NOTE(review): no max_len is passed here, unlike the call above; the
        # mask must come out seq_len + 1 wide to match trn_scr -- confirm what
        # sequence_mask does when max(lens) < seq_len.
        mask = sequence_mask(lens + 1).float()
        trn_scr = trn_scr * mask
        # the trailing squeeze presumably targets older PyTorch releases where
        # sum(dim) kept the reduced dimension -- TODO confirm
        score = trn_scr.sum(1).squeeze(-1)

        return score
model.py 文件源码 项目:pytorch-bilstmcrf 作者: kaniblu 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def _bilstm_score(self, logits, y, lens):
        """Sum, per batch element, the logits selected by the labels ``y``.

        For every time step the logit at index ``y[b, t]`` is picked along
        dimension 2 of ``logits``; the picked values are multiplied by
        ``sequence_mask(lens)`` and summed over dimension 1.
        """
        # picked[b, t] == logits[b, t, y[b, t]]
        picked = torch.gather(logits, 2, y.unsqueeze(-1))
        picked = picked.squeeze(-1)

        # zero out the contributions the mask marks as invalid
        valid = sequence_mask(lens).float()
        masked = picked * valid

        return masked.sum(1).squeeze(-1)
Normalize.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def updateGradInput(self, input, gradOutput):
        """Compute ``self.gradInput`` from ``input`` and ``gradOutput``.

        Both arguments must be 2-dimensional (asserted below); ``input`` is
        treated as ``n`` row vectors of dimensionality ``d``. The method
        reads ``self.p``, ``self.eps``, ``self.norm``, ``self.normp``,
        ``self._indices`` and ``self.buffer`` (presumably prepared by the
        forward pass -- not visible in this block) and reuses
        ``self._gradInput``, ``self.cross``, ``self.buffer`` and
        ``self.buffer2`` as scratch tensors.

        The ``torch.xxx(result, a, b)`` calls use a legacy call form in which
        the first tensor appears to be the output -- TODO confirm against the
        PyTorch version this file targets.

        Returns:
            ``self.gradInput``, a view of the internal buffer with the size
            of ``input``.
        """
        assert input.dim() == 2
        assert gradOutput.dim() == 2

        input_size = input.size()
        n = input.size(0) # batch size
        d = input.size(1) # dimensionality of vectors

        # lazily create the scratch tensors with the same type as input
        self._gradInput = self._gradInput or input.new()
        self.cross = self.cross or input.new()
        # compute diagonal term with gradOutput
        self._gradInput.resize_(n, d)
        if self.p == float('inf'):
                # specialization for the inf case
                torch.mul(self._gradInput, self.norm.view(n, 1,1).expand(n, d,1), gradOutput)
                self.buffer.resize_as_(input).zero_()
                self.cross.resize_(n, 1)
                # pick input at self._indices along dim 1 into self.cross,
                # scale by the norm and scatter back into the zeroed buffer
                torch.gather(self.cross, input, 1, self._indices)
                self.cross.div_(self.norm)
                self.buffer.scatter_(1, self._indices, self.cross)
        else:
                torch.mul(self._gradInput, self.normp.view(n, 1).expand(n, d), gradOutput)
                # small optimizations for different p
                # buffer = input*|input|^(p-2)
                # for non-even p, need to add absolute value
                if self.p % 2 != 0:
                    if self.p < 2:
                        # add eps to avoid possible division by 0
                        torch.abs(self.buffer, input).add_(self.eps).pow_(self.p-2).mul_(input)
                    else:
                        torch.abs(self.buffer, input).pow_(self.p-2).mul_(input)
                # special case for p == 2, pow(x, 0) = 1
                elif self.p == 2:
                    self.buffer.copy_(input)
                else:
                    # p is even and > 2, pow(x, p) is always positive
                    torch.pow(self.buffer, input, self.p-2).mul_(input)

        # compute cross term in two steps
        self.cross.resize_(n, 1)

        # instead of having a huge temporary matrix (b1*b2),
        # do the computations as b1*(b2*gradOutput). This avoids redundant
        # computation and also a huge buffer of size n*d^2
        self.buffer2 = self.buffer2 or input.new() # nxd
        torch.mul(self.buffer2, input, gradOutput)
        # row-wise sum of input * gradOutput
        torch.sum(self.cross, self.buffer2, 1)

        self.buffer.mul_(self.cross.expand_as(self.buffer))
        # legacy add_(value, tensor) form: _gradInput += -1 * buffer
        self._gradInput.add_(-1, self.buffer)

        # reuse cross buffer for normalization
        if self.p == float('inf'):
            torch.mul(self.cross, self.norm, self.norm)
        else:
            torch.mul(self.cross, self.normp, self.norm)

        self._gradInput.div_(self.cross.expand(n, d))

        self.gradInput = self._gradInput.view(input_size)
        return self.gradInput
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_topk(self):
        """Check ``topk`` against a sort-based reference on a random 3-d tensor.

        Random ``k``/``dim`` choices are tried on the tensor and on transposed
        views of it, for both sort directions.
        """
        def sort_reference(tensor, k, dim, descending):
            # full sort, then keep the leading k entries along `dim`
            values, positions = tensor.sort(dim, descending)
            return values.narrow(dim, 0, k), positions.narrow(dim, 0, k)

        def check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim):
            # Values should be exactly equivalent
            self.assertEqual(ref_vals, got_vals, 0)

            # Identical indices need no further checking.
            if ref_inds.eq(got_inds).all():
                return

            # Indices may legitimately differ (no guaranteed selection order
            # among equal values), so gather from the input with the topk
            # indices and compare with the sorted values instead.
            gathered = tensor.gather(dim, got_inds)
            self.assertEqual(ref_vals, gathered, 0)

        def run_case(tensor, k, dim, descending):
            got_vals, got_inds = tensor.topk(k, dim, descending, True)
            ref_vals, ref_inds = sort_reference(tensor, k, dim, descending)
            check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim)

        shape = [random.randint(1, SIZE) for _ in range(3)]
        base = torch.rand(*shape)

        for _k_round in range(3):
            for _dim_round in range(3):
                for do_transpose in (True, False):
                    for descending in (True, False):
                        if do_transpose:
                            # draw two distinct dimensions to swap
                            first = random.randrange(base.ndimension())
                            second = random.randrange(base.ndimension())
                            while second == first:
                                second = random.randrange(base.ndimension())
                            candidate = base.transpose(first, second)
                        else:
                            candidate = base

                        dim = random.randrange(candidate.ndimension())
                        k = random.randint(1, candidate.size(dim))
                        run_case(candidate, k, dim, descending)
util.py 文件源码 项目:allennlp 作者: allenai 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Weighted cross entropy over a batch of sequences.

    The weights scale the loss contribution of individual sequence elements
    (for example to mask out padding); they are not per-class weights as in
    :func:`torch.nn.CrossEntropyLoss()`.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        Unnormalized class scores of size (batch_size, sequence_length, num_classes).
    targets : ``torch.LongTensor``, required.
        Index of the true class at each step, of size (batch, sequence_length).
    weights : ``torch.FloatTensor``, required.
        Per-element weights of size (batch, sequence_length).
    batch_average : bool, optional, (default = True).
        Whether to average the loss over the non-empty sequences of the batch
        or to return one loss per batch element.

    Returns
    -------
    A torch.FloatTensor holding the cross entropy loss: a scalar when
    ``batch_average == True``, otherwise a vector of shape (batch_size,).

    """
    num_classes = logits.size(-1)

    # shape : (batch * sequence_length, num_classes)
    flat_log_probs = torch.nn.functional.log_softmax(logits.view(-1, num_classes), dim=-1)
    # shape : (batch * sequence_length, 1)
    flat_targets = targets.view(-1, 1).long()

    # The targets are one-hot, so only the log-probability at the target index
    # contributes; gather picks exactly that entry for every element.
    # shape : (batch * sequence_length, 1)
    picked_log_probs = torch.gather(flat_log_probs, dim=1, index=flat_targets)

    # shape : (batch, sequence_length)
    weighted_nll = (-picked_log_probs).view(*targets.size()) * weights.float()

    # shape : (batch_size,) -- the epsilon guards against all-zero weight rows
    per_batch_loss = weighted_nll.sum(1) / (weights.sum(1).float() + 1e-13)

    if not batch_average:
        return per_batch_loss

    # average only over sequences that carry any weight at all
    non_empty_count = (weights.sum(1) > 0).float().sum() + 1e-13
    return per_batch_loss.sum() / non_empty_count
train.py 文件源码 项目:pytorch-skipthoughts 作者: kaniblu 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def val_sents(self, data, dec_logits):
        """Turn a slice of a validation batch into readable sentences.

        ``data`` is unpacked into ``x, x_lens, ys_i, ys_t, ys_lens, xys_idx``.
        The ``ys_*`` tensors and ``dec_logits`` are re-ordered with the
        argsort of ``xys_idx``, cut down to the first ``self.previews``
        entries and converted to strings through ``self.model.vocab.i2f``.

        Returns:
            ``(x_sents, yi_sents, yt_sents, o_sents)``: a list of strings for
            ``x`` and lists of lists of strings for the decoder inputs,
            decoder targets and arg-max decoder outputs.
        """
        vocab, previews = self.model.vocab, self.previews
        x, x_lens, ys_i, ys_t, ys_lens, xys_idx = data

        if self.batch_first:
            # swap the first two dimensions of every decoder-side tensor
            cdata = [ys_i, ys_t, ys_lens, xys_idx, dec_logits]
            cdata = [d.transpose(1, 0).contiguous() for d in cdata]
            ys_i, ys_t, ys_lens, xys_idx, dec_logits = cdata

        # argsort along dim 1; presumably xys_idx holds a permutation per row,
        # so this is its inverse -- TODO confirm against the data loader
        _, xys_ridx = torch.sort(xys_idx, 1)
        xys_ridx_exp = xys_ridx.unsqueeze(-1).expand_as(ys_i)
        ys_i = torch.gather(ys_i, 1, xys_ridx_exp)
        ys_t = torch.gather(ys_t, 1, xys_ridx_exp)
        dec_logits = [torch.index_select(logits, 0, xy_ridx)
                      for logits, xy_ridx in zip(dec_logits, xys_ridx)]
        ys_lens = torch.gather(ys_lens, 1, xys_ridx)

        # keep only the first `previews` examples
        x, x_lens = x[:previews], x_lens[:previews]
        ys_i, ys_t = ys_i[:, :previews], ys_t[:, :previews]
        # max(2)[1] is the arg-max over dimension 2 of each logits tensor; the
        # squeeze presumably targets PyTorch versions where max kept that
        # dimension -- TODO confirm
        dec_logits = torch.cat(
            [logits[:previews].max(2)[1].squeeze(-1).unsqueeze(0)
             for logits in dec_logits], 0)
        ys_lens = ys_lens[:, :previews]

        # bring the previewed examples to the leading dimension
        ys_i, ys_t = ys_i.transpose(1, 0), ys_t.transpose(1, 0)
        dec_logits, ys_lens = dec_logits.transpose(1, 0), ys_lens.transpose(1,
                                                                            0)

        # convert everything to nested Python lists
        x, x_lens = x.data.tolist(), x_lens.data.tolist()
        ys_i, ys_t = ys_i.data.tolist(), ys_t.data.tolist()
        dec_logits, ys_lens = dec_logits.data.tolist(), ys_lens.data.tolist()

        def to_sent(data, length, vocab):
            # join the first `length` ids of one sequence, mapped through i2f
            return " ".join(vocab.i2f[data[i]] for i in range(length))

        def to_sents(data, lens, vocab):
            # apply to_sent to every (sequence, length) pair
            return [to_sent(d, l, vocab) for d, l in zip(data, lens)]

        x_sents = to_sents(x, x_lens, vocab)
        yi_sents = [to_sents(yi, y_lens, vocab) for yi, y_lens in
                    zip(ys_i, ys_lens)]
        yt_sents = [to_sents(yt, y_lens, vocab) for yt, y_lens in
                    zip(ys_t, ys_lens)]
        o_sents = [to_sents(dec_logit, y_lens, vocab)
                   for dec_logit, y_lens in zip(dec_logits, ys_lens)]

        return x_sents, yi_sents, yt_sents, o_sents
test_torch.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_topk(self):
        """Check ``topk`` against a sort-based reference on a random 3-d tensor.

        Random ``k``/``dim`` choices are tried on the tensor and on transposed
        views of it, for both sort directions.
        """
        def sort_reference(tensor, k, dim, descending):
            # full sort, then keep the leading k entries along `dim`
            values, positions = tensor.sort(dim, descending)
            return values.narrow(dim, 0, k), positions.narrow(dim, 0, k)

        def check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim):
            # Values should be exactly equivalent
            self.assertEqual(ref_vals, got_vals, 0)

            # Identical indices need no further checking.
            if ref_inds.eq(got_inds).all():
                return

            # Indices may legitimately differ (no guaranteed selection order
            # among equal values), so gather from the input with the topk
            # indices and compare with the sorted values instead.
            gathered = tensor.gather(dim, got_inds)
            self.assertEqual(ref_vals, gathered, 0)

        def run_case(tensor, k, dim, descending):
            got_vals, got_inds = tensor.topk(k, dim, descending, True)
            ref_vals, ref_inds = sort_reference(tensor, k, dim, descending)
            check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim)

        shape = [random.randint(1, SIZE) for _ in range(3)]
        base = torch.rand(*shape)

        for _k_round in range(3):
            for _dim_round in range(3):
                for do_transpose in (True, False):
                    for descending in (True, False):
                        if do_transpose:
                            # draw two distinct dimensions to swap
                            first = random.randrange(base.ndimension())
                            second = random.randrange(base.ndimension())
                            while second == first:
                                second = random.randrange(base.ndimension())
                            candidate = base.transpose(first, second)
                        else:
                            candidate = base

                        dim = random.randrange(candidate.ndimension())
                        k = random.randint(1, candidate.size(dim))
                        run_case(candidate, k, dim, descending)
test_torch.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def test_topk(self):
        """Check ``topk`` against a sort-based reference on a random 3-d tensor.

        Random ``k``/``dim`` choices are tried on the tensor and on transposed
        views of it, for both sort directions.
        """
        def sort_reference(tensor, k, dim, descending):
            # full sort, then keep the leading k entries along `dim`
            values, positions = tensor.sort(dim, descending)
            return values.narrow(dim, 0, k), positions.narrow(dim, 0, k)

        def check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim):
            # Values should be exactly equivalent
            self.assertEqual(ref_vals, got_vals, 0)

            # Identical indices need no further checking.
            if ref_inds.eq(got_inds).all():
                return

            # Indices may legitimately differ (no guaranteed selection order
            # among equal values), so gather from the input with the topk
            # indices and compare with the sorted values instead.
            gathered = tensor.gather(dim, got_inds)
            self.assertEqual(ref_vals, gathered, 0)

        def run_case(tensor, k, dim, descending):
            got_vals, got_inds = tensor.topk(k, dim, descending, True)
            ref_vals, ref_inds = sort_reference(tensor, k, dim, descending)
            check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim)

        shape = [random.randint(1, SIZE) for _ in range(3)]
        base = torch.rand(*shape)

        for _k_round in range(3):
            for _dim_round in range(3):
                for do_transpose in (True, False):
                    for descending in (True, False):
                        if do_transpose:
                            # draw two distinct dimensions to swap
                            first = random.randrange(base.ndimension())
                            second = random.randrange(base.ndimension())
                            while second == first:
                                second = random.randrange(base.ndimension())
                            candidate = base.transpose(first, second)
                        else:
                            candidate = base

                        dim = random.randrange(candidate.ndimension())
                        k = random.randint(1, candidate.size(dim))
                        run_case(candidate, k, dim, descending)
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 41 收藏 0 点赞 0 评论 0
def test_topk(self):
        """Check ``topk`` against a sort-based reference on a random 3-d tensor.

        Random ``k``/``dim`` choices are tried on the tensor and on transposed
        views of it, for both sort directions.
        """
        def sort_reference(tensor, k, dim, descending):
            # full sort, then keep the leading k entries along `dim`
            values, positions = tensor.sort(dim, descending)
            return values.narrow(dim, 0, k), positions.narrow(dim, 0, k)

        def check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim):
            # Values should be exactly equivalent
            self.assertEqual(ref_vals, got_vals, 0)

            # Identical indices need no further checking.
            if ref_inds.eq(got_inds).all():
                return

            # Indices may legitimately differ (no guaranteed selection order
            # among equal values), so gather from the input with the topk
            # indices and compare with the sorted values instead.
            gathered = tensor.gather(dim, got_inds)
            self.assertEqual(ref_vals, gathered, 0)

        def run_case(tensor, k, dim, descending):
            got_vals, got_inds = tensor.topk(k, dim, descending, True)
            ref_vals, ref_inds = sort_reference(tensor, k, dim, descending)
            check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim)

        shape = [random.randint(1, SIZE) for _ in range(3)]
        base = torch.rand(*shape)

        for _k_round in range(3):
            for _dim_round in range(3):
                for do_transpose in (True, False):
                    for descending in (True, False):
                        if do_transpose:
                            # draw two distinct dimensions to swap
                            first = random.randrange(base.ndimension())
                            second = random.randrange(base.ndimension())
                            while second == first:
                                second = random.randrange(base.ndimension())
                            candidate = base.transpose(first, second)
                        else:
                            candidate = base

                        dim = random.randrange(candidate.ndimension())
                        k = random.randint(1, candidate.size(dim))
                        run_case(candidate, k, dim, descending)
pointer.py 文件源码 项目:awd-lstm-lm 作者: salesforce 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def evaluate(data_source, batch_size=10, window=args.window):
    """Evaluate the global ``model`` on ``data_source`` with a pointer mixture.

    The data is walked in steps of ``args.bptt``. For each target position
    the vocabulary softmax is, once more than ``window`` positions of history
    exist, blended with a pointer distribution computed from the last
    ``window`` stored hidden states and their next-word one-hot rows
    (weights ``args.lambdasm`` and ``1 - args.lambdasm``, temperature
    ``args.theta``). The negative log of the probability assigned to the
    target is accumulated.

    Relies on module-level names that are not defined in this block:
    ``args``, ``model``, ``corpus``, ``get_batch``, ``one_hot``,
    ``repackage_hidden``. Note that the default for ``window`` is read from
    ``args.window`` once, when the function is defined.

    Returns:
        ``total_loss / len(data_source)``.
    """
    # Turn on evaluation mode which disables dropout.
    if args.model == 'QRNN': model.reset()
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    # rolling history: one-hot next words and the matching hidden states
    next_word_history = None
    pointer_history = None
    for i in range(0, data_source.size(0) - 1, args.bptt):
        # progress report: running exp(average loss) up to position i
        if i > 0: print(i, len(data_source), math.exp(total_loss / i))
        data, targets = get_batch(data_source, i, evaluation=True, args=args)
        output, hidden, rnn_outs, _ = model(data, hidden, return_h=True)
        # hidden states of the last entry in rnn_outs, singleton dims removed
        rnn_out = rnn_outs[-1].squeeze()
        output_flat = output.view(-1, ntokens)
        ###
        # Fill pointer history
        # offset of this chunk's first position inside the history tensors
        start_idx = len(next_word_history) if next_word_history is not None else 0
        # presumably one_hot returns a (1, ntokens) row per target -- TODO confirm
        next_word_history = torch.cat([one_hot(t.data[0], ntokens) for t in targets]) if next_word_history is None else torch.cat([next_word_history, torch.cat([one_hot(t.data[0], ntokens) for t in targets])])
        #print(next_word_history)
        # re-wrapping .data in a fresh Variable drops the autograd history
        pointer_history = Variable(rnn_out.data) if pointer_history is None else torch.cat([pointer_history, Variable(rnn_out.data)], dim=0)
        #print(pointer_history)
        ###
        # Built-in cross entropy
        # total_loss += len(data) * criterion(output_flat, targets).data[0]
        ###
        # Manual cross entropy
        # softmax_output_flat = torch.nn.functional.softmax(output_flat)
        # soft = torch.gather(softmax_output_flat, dim=1, index=targets.view(-1, 1))
        # entropy = -torch.log(soft)
        # total_loss += len(data) * entropy.mean().data[0]
        ###
        # Pointer manual cross entropy
        loss = 0
        softmax_output_flat = torch.nn.functional.softmax(output_flat)
        for idx, vocab_loss in enumerate(softmax_output_flat):
            # default: use the plain vocabulary distribution
            p = vocab_loss
            if start_idx + idx > window:
                # the `window` history rows immediately before this position
                valid_next_word = next_word_history[start_idx + idx - window:start_idx + idx]
                valid_pointer_history = pointer_history[start_idx + idx - window:start_idx + idx]
                # dot product of each stored hidden state with the current one
                logits = torch.mv(valid_pointer_history, rnn_out[idx])
                theta = args.theta
                ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1)
                # attention-weighted sum of the one-hot next-word rows
                ptr_dist = (ptr_attn.expand_as(valid_next_word) * valid_next_word).sum(0).squeeze()
                lambdah = args.lambdasm
                p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
            ###
            target_loss = p[targets[idx].data]
            loss += (-torch.log(target_loss)).data[0]
        total_loss += loss / batch_size
        ###
        hidden = repackage_hidden(hidden)
        # keep only the most recent `window` history rows for the next chunk
        next_word_history = next_word_history[-window:]
        pointer_history = pointer_history[-window:]
    return total_loss / len(data_source)

# Load the best saved model.
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 65 收藏 0 点赞 0 评论 0
def test_topk(self):
        """Check ``topk`` against a sort-based reference on a random 3-d tensor.

        Random ``k``/``dim`` choices are tried on the tensor and on transposed
        views of it, for both sort directions.
        """
        def sort_reference(tensor, k, dim, descending):
            # full sort, then keep the leading k entries along `dim`
            values, positions = tensor.sort(dim, descending)
            return values.narrow(dim, 0, k), positions.narrow(dim, 0, k)

        def check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim):
            # Values should be exactly equivalent
            self.assertEqual(ref_vals, got_vals, 0)

            # Identical indices need no further checking.
            if ref_inds.eq(got_inds).all():
                return

            # Indices may legitimately differ (no guaranteed selection order
            # among equal values), so gather from the input with the topk
            # indices and compare with the sorted values instead.
            gathered = tensor.gather(dim, got_inds)
            self.assertEqual(ref_vals, gathered, 0)

        def run_case(tensor, k, dim, descending):
            got_vals, got_inds = tensor.topk(k, dim, descending, True)
            ref_vals, ref_inds = sort_reference(tensor, k, dim, descending)
            check_equivalent(tensor, ref_vals, ref_inds, got_vals, got_inds, dim)

        shape = [random.randint(1, SIZE) for _ in range(3)]
        base = torch.rand(*shape)

        for _k_round in range(3):
            for _dim_round in range(3):
                for do_transpose in (True, False):
                    for descending in (True, False):
                        if do_transpose:
                            # draw two distinct dimensions to swap
                            first = random.randrange(base.ndimension())
                            second = random.randrange(base.ndimension())
                            while second == first:
                                second = random.randrange(base.ndimension())
                            candidate = base.transpose(first, second)
                        else:
                            candidate = base

                        dim = random.randrange(candidate.ndimension())
                        k = random.randint(1, candidate.size(dim))
                        run_case(candidate, k, dim, descending)
Merge.py 文件源码 项目:DCN 作者: alexnowakvila 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def Decoder(self, input, hidden_encoder, phis,
                input_target=None, target=None):
        """Run ``self.n`` decoding steps and collect the attention rows.

        At every step the decoder cell is advanced, attention scores over
        ``hidden_encoder`` are computed and normalised with the masked
        softmax ``self.softmax_m`` using row ``n`` of ``phis``. When
        ``target`` is given, the next input is gathered from ``input_target``
        at the target index; otherwise it is the attention-weighted sum of
        ``input``.

        Uses ``xrange`` (Python 2) and the external names ``dtype`` and
        ``dtype_l``, which are not defined in this block.

        Returns:
            ``output``, allocated as (batch_size, n, n), whose row ``n``
            holds the attention computed at step ``n``.
        """
        # feed the ground-truth target back in whenever one is supplied
        feed_target = False
        if target is not None:
            feed_target = True
        # N_n is the number of elements of the scope of the n-th element
        N = phis.sum(2).squeeze().unsqueeze(2).expand(self.batch_size, self.n,
                                                      self.hidden_size)
        output = (Variable(torch.ones(self.batch_size, self.n, self.n))
                  .type(dtype))
        # initial hidden state: encoder state at position (N_0 - 1) mod n
        index = ((N[:, 0] - 1) % (self.n)).type(dtype_l).unsqueeze(1)
        hidden = (torch.gather(hidden_encoder, 1, index)).squeeze()
        # W1xe size: (batch_size, n + 1, hidden_size)
        W1xe = (torch.bmm(hidden_encoder, self.W1.unsqueeze(0).expand(
                self.batch_size, self.hidden_size, self.hidden_size)))
        # init token
        start = (self.init_token.unsqueeze(0).expand(self.batch_size,
                 self.input_size))
        input_step = start
        for n in xrange(self.n):
            # decouple interaction between different scopes by looking at
            # subdiagonal elements of Phi
            if n > 0:
                # t blends between carrying the state over (t == 1) and
                # re-initialising it (t == 0)
                t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                     self.batch_size, self.hidden_size))
                index = (((N[:, n] + n - 1) % (self.n)).type(dtype_l)
                         .unsqueeze(1))
                init_hidden = (torch.gather(hidden_encoder, 1, index)
                               .squeeze())
                hidden = t * hidden + (1 - t) * init_hidden
                # same gate, expanded to the input width
                t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                     self.batch_size, self.input_size))
                input_step = t * input_step + (1 - t) * start
            # Compute next state
            hidden = self.decoder_cell(input_step, hidden)
            # Compute pairwise interactions
            u = self.attention(hidden, W1xe, hidden_encoder, tanh=True)
            # Normalize interactions by taking the masked softmax by phi
            attn = self.softmax_m(phis[:, n].squeeze(), u)
            if feed_target:
                # feed next step with target
                # (the local name `next` shadows the builtin inside this loop)
                next = (target[:, n].unsqueeze(1).unsqueeze(2)
                        .expand(self.batch_size, 1, self.input_size)
                        .type(dtype_l))
                input_step = torch.gather(input_target, 1, next).squeeze()
            else:
                # blend inputs
                input_step = (torch.sum(attn.unsqueeze(2).expand(
                              self.batch_size, self. n,
                              self.input_size) * input, 1)).squeeze()
            # Update output
            output[:, n] = attn
        return output
Merge.py 文件源码 项目:DCN 作者: alexnowakvila 项目源码 文件源码 阅读 46 收藏 0 点赞 0 评论 0
def Decoder(self, input, hidden_encoder, phis,
                input_target=None, target=None):
        """Run ``self.n`` decoding steps and collect the attention rows.

        At every step the decoder cell is advanced, attention scores over
        ``hidden_encoder`` are computed and normalised with the masked
        softmax ``self.softmax_m``; the mask is row ``n`` of ``phis`` with a
        column of ones prepended. When ``target`` is given, the next input is
        gathered from ``input_target`` at the target index; otherwise it is
        gathered from ``input`` at the arg-max of the attention.

        Uses ``xrange`` (Python 2) and the external names ``dtype`` and
        ``dtype_l``, which are not defined in this block.

        Returns:
            ``output``, allocated as (batch_size, n, n + 1), whose row ``n``
            holds the attention computed at step ``n``.
        """
        # feed the ground-truth target back in whenever one is supplied
        feed_target = False
        if target is not None:
            feed_target = True
        # N[:, n] is the number of elements of the scope of the n-th element
        N = phis.sum(2).squeeze().unsqueeze(2).expand(self.batch_size, self.n,
                                                      self.hidden_size)
        output = (Variable(torch.ones(self.batch_size, self.n, self.n + 1))
                  .type(dtype))
        # detach: the index is not meant to take part in backpropagation
        index = ((N[:, 0] - 1) % (self.n)).type(dtype_l).unsqueeze(1).detach()
        # index + 1: encoder positions are shifted by one here, presumably
        # because slot 0 of hidden_encoder is an extra leading entry (see the
        # "n + 1" size noted below) -- TODO confirm
        hidden = (torch.gather(hidden_encoder, 1, index + 1)).squeeze()
        # W1xe size: (batch_size, n + 1, hidden_size)
        W1xe = (torch.bmm(hidden_encoder, self.W1.unsqueeze(0).expand(
                self.batch_size, self.hidden_size, self.hidden_size)))
        # init token
        start = (self.init_token.unsqueeze(0).expand(self.batch_size,
                 self.input_size))
        input_step = start
        for n in xrange(self.n):
            # decouple interaction between different scopes by looking at
            # subdiagonal elements of Phi
            if n > 0:
                # t blends between carrying the state over (t == 1) and
                # re-initialising it (t == 0)
                t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                     self.batch_size, self.hidden_size))
                index = (((N[:, n] + n - 1) % (self.n)).type(dtype_l)
                         .unsqueeze(1)).detach()
                init_hidden = (torch.gather(hidden_encoder, 1, index + 1)
                               .squeeze())
                hidden = t * hidden + (1 - t) * init_hidden
                # same gate, expanded to the input width
                t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                     self.batch_size, self.input_size))
                input_step = t * input_step + (1 - t) * start
            # Compute next state
            hidden = self.decoder_cell(input_step, hidden)
            # Compute pairwise interactions
            u = self.attention(hidden, W1xe, hidden_encoder)
            # Normalize interactions by taking the masked softmax by phi
            # a column of ones is prepended so that slot 0 is never masked out
            pad = Variable(torch.ones(self.batch_size, 1)).type(dtype)
            mask = torch.cat((pad, phis[:, n].squeeze()), 1)
            attn = self.softmax_m(mask, u)
            if feed_target:
                # feed next step with target
                # (the local name `next` shadows the builtin inside this loop)
                next = (target[:, n].unsqueeze(1).unsqueeze(2)
                        .expand(self.batch_size, 1, self.input_size)
                        .type(dtype_l))
                input_step = torch.gather(input_target, 1, next).squeeze()
            else:
                # not blend
                # take the input at the position with the highest attention
                index = attn.max(1)[1].squeeze()
                next = (index.unsqueeze(1).unsqueeze(2)
                        .expand(self.batch_size, 1, self.input_size)
                        .type(dtype_l))
                input_step = torch.gather(input, 1, next).squeeze()
                # blend inputs
                # input_step = (torch.sum(attn.unsqueeze(2).expand(
                #               self.batch_size, self. n + 1,
                #               self.input_size) * input, 1)).squeeze()
            # Update output
            output[:, n] = attn
        return output
model.py 文件源码 项目:pytorch-bilstmcrf 作者: kaniblu 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def viterbi_decode(self, logits, lens):
        """Borrowed from pytorch tutorial

        Batched Viterbi decoding: a forward pass keeps, per label, the best
        score so far and a back-pointer to the best previous label; a
        backward pass then follows the back-pointers from the best final
        label.

        Arguments:
            logits: [batch_size, seq_len, n_labels] FloatTensor
            lens: [batch_size] LongTensor

        Returns:
            scores: the maximum final score per batch element
            paths: the label indices obtained by back-tracking, concatenated
                along dimension 1
        """
        batch_size, seq_len, n_labels = logits.size()
        # start with every label at -10000 except the start label at 0
        vit = logits.data.new(batch_size, self.n_labels).fill_(-10000)
        vit[:, self.start_idx] = 0
        vit = Variable(vit)
        # remaining valid steps per sequence, counted down inside the loop
        c_lens = lens.clone()

        # iterate over time steps
        logits_t = logits.transpose(1, 0)
        pointers = []
        for logit in logits_t:
            # vit_exp[b, i, j] == vit[b, j]; trn_exp[b, i, j] == transitions[i, j]
            vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels)
            trn_exp = self.transitions.unsqueeze(0).expand_as(vit_exp)
            vit_trn_sum = vit_exp + trn_exp
            # best previous label j for every next label i
            vt_max, vt_argmax = vit_trn_sum.max(2)

            vt_max = vt_max.squeeze(-1)
            vit_nxt = vt_max + logit
            pointers.append(vt_argmax.squeeze(-1).unsqueeze(0))

            # only advance sequences that still have valid steps left
            mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt)
            vit = mask * vit_nxt + (1 - mask) * vit

            # on a sequence's last valid step, add the transitions[stop_idx] row
            mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt)
            vit += mask * self.transitions[ self.stop_idx ].unsqueeze(0).expand_as(vit_nxt)

            c_lens = c_lens - 1

        pointers = torch.cat(pointers)
        scores, idx = vit.max(1)
        idx = idx.squeeze(-1)
        paths = [idx.unsqueeze(1)]

        # NOTE(review): back-pointers recorded at padded time steps are
        # followed here without any masking -- confirm that the paths of
        # sequences shorter than seq_len are still what callers expect.
        for argmax in reversed(pointers):
            idx_exp = idx.unsqueeze(-1)
            idx = torch.gather(argmax, 1, idx_exp)
            idx = idx.squeeze(-1)

            paths.insert(0, idx.unsqueeze(1))

        # paths[0] comes from the first time step's back-pointers and is dropped
        paths = torch.cat(paths[1:], 1)
        scores = scores.squeeze(-1)

        return scores, paths


问题


面经


文章

微信
公众号

扫码关注公众号