Python torch.bmm() 函数的实例源码

models.py 文件源码 项目:Structured-Self-Attentive-Sentence-Embedding 作者: ExplorerFreda 项目源码 文件源码 阅读 41 收藏 0 点赞 0 评论 0
def forward(self, inp, hidden):
        """Structured self-attention over BiLSTM hidden states.

        Args:
            inp: token ids laid out as [len, bsz] (transposed to [bsz, len]
                below); positions equal to the '<pad>' id are masked out.
            hidden: recurrent state passed straight to ``self.bilstm.forward``.

        Returns:
            (embedding, alphas): ``bmm(alphas, outp)`` of shape
            [bsz, hops, nhid] and the attention weights [bsz, hops, len].
        """
        outp = self.bilstm.forward(inp, hidden)[0]
        bsz, seq_len, nhid = outp.size()
        hops = self.attention_hops

        # Token ids as [bsz, 1, len], repeated (as a view) once per hop.
        tokens = torch.transpose(inp, 0, 1).contiguous().view(bsz, 1, seq_len)
        tokens_per_hop = tokens.expand(bsz, hops, seq_len)  # [bsz, hop, len]

        flat_states = outp.view(-1, nhid)  # [bsz*len, nhid]
        hbar = self.tanh(self.ws1(self.drop(flat_states)))  # [bsz*len, attention-unit]
        scores = self.ws2(hbar).view(bsz, seq_len, -1)  # [bsz, len, hop]
        scores = torch.transpose(scores, 1, 2).contiguous()  # [bsz, hop, len]

        # A large negative offset drives padded positions to ~0 after softmax.
        pad_mask = (tokens_per_hop == self.dictionary.word2idx['<pad>']).float()
        scores = scores + (-10000 * pad_mask)

        alphas = self.softmax(scores.view(-1, seq_len))  # [bsz*hop, len]
        alphas = alphas.view(bsz, hops, seq_len)  # [bsz, hop, len]
        return torch.bmm(alphas, outp), alphas
blas.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def backward(self, grad_output):
        """Backward of the legacy batched add-matmul-reduce Function.

        ``grad_output`` has no batch dimension. It is broadcast (as an expanded
        view) to ``batch1.size(0) x batch1.size(1) x batch2.size(2)`` for the
        two batch products.

        The additive input's gradient is multiplied by ``self.alpha`` and both
        batch gradients by ``self.beta``. That suggests the forward computed
        ``alpha * add + beta * sum_i batch1[i] @ batch2[i]``; confirm against
        the forward.

        Returns:
            ``(grad_add_matrix, grad_batch1, grad_batch2)``; any entry whose
            input does not need a gradient is ``None``.
        """
        batch1, batch2 = self.saved_tensors
        needs = self.needs_input_grad
        grad_add_matrix = None
        grad_batch1 = None
        grad_batch2 = None

        if needs[0]:
            grad_add_matrix = (grad_output.mul(self.alpha)
                               if self.alpha != 1 else grad_output)

        if any(needs[1:]):
            expanded = grad_output.unsqueeze(0).expand(
                batch1.size(0), batch1.size(1), batch2.size(2))
            if needs[1]:
                grad_batch1 = torch.bmm(expanded, batch2.transpose(1, 2))
                if self.beta != 1:
                    grad_batch1.mul_(self.beta)
            if needs[2]:
                grad_batch2 = torch.bmm(batch1.transpose(1, 2), expanded)
                if self.beta != 1:
                    grad_batch2.mul_(self.beta)

        return grad_add_matrix, grad_batch1, grad_batch2
blas.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def backward(self, grad_output):
        """Backward of the legacy batched-add + batched-matmul Function.

        ``grad_output`` already carries the batch dimension, so both batch
        products use it directly. The additive input's gradient is scaled by
        ``self.alpha`` and the batch gradients by ``self.beta``.

        Returns:
            ``(grad_add_batch, grad_batch1, grad_batch2)``, with ``None`` for
            every input whose ``self.needs_input_grad`` flag is false.
        """
        batch1, batch2 = self.saved_tensors
        flags = self.needs_input_grad

        if flags[0]:
            grad_add_batch = grad_output if self.alpha == 1 else grad_output.mul(self.alpha)
        else:
            grad_add_batch = None

        if flags[1]:
            grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
            if self.beta != 1:
                grad_batch1.mul_(self.beta)
        else:
            grad_batch1 = None

        if flags[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
            if self.beta != 1:
                grad_batch2.mul_(self.beta)
        else:
            grad_batch2 = None

        return grad_add_batch, grad_batch1, grad_batch2
MM.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def updateOutput(self, input):
        """Forward pass of the legacy MM module: (batched) product of a pair.

        Args:
            input: sequence ``(a, b)`` of two tensors, both 2-D or both 3-D.
                For 3-D inputs dim 0 is the batch dimension used by
                ``torch.bmm``. ``self.transA`` / ``self.transB`` transpose the
                (per-batch) matrices before multiplying.

        Returns:
            ``self.output``, resized to the product shape and filled by the
            mm/bmm call.

        Raises:
            AssertionError: if ``input`` is not a pair of equal-rank 2-D or
                3-D tensors.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            # Old torch calling convention: the leading argument receives the
            # result (inferred from the resize_/return pattern here).
            torch.mm(self.output, a, b)
        else:
            # A 3-D tensor only has dims 0..2: transpose the matrix dims (1, 2)
            # and leave the batch dim 0 alone. (2, 3) would be out of range.
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(self.output, a, b)

        return self.output
MV.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def updateOutput(self, input):
        """Legacy MV module forward: (batched) matrix-vector product.

        ``input`` is ``(M, v)``: a 2-D ``M`` with a 1-D ``v``, or a 3-D ``M``
        (batch first) with a 2-D ``v``. When ``self.trans`` is set the
        (per-batch) matrix is transposed first. The product is written into
        ``self.output``, which is returned.

        Raises:
            AssertionError: for any other combination of ranks.
        """
        M, v = input
        assert M.ndimension() == 2 or M.ndimension() == 3

        if M.ndimension() != 2:
            assert v.ndimension() == 2
            if self.trans:
                M = M.transpose(1, 2)
            batch, rows = M.size(0), M.size(1)
            self.output.resize_(batch, rows, 1)
            # Treat each vector as a column so bmm can be used, then drop the
            # trailing singleton dim from the result.
            v_col = v.view(v.size(0), v.size(1), 1)
            torch.bmm(self.output, M, v_col).resize_(batch, rows)
            return self.output

        assert v.ndimension() == 1
        if self.trans:
            M = M.transpose(0, 1)
        self.output.resize_(M.size(0))
        torch.mv(self.output, M, v)
        return self.output
layers.py 文件源码 项目:torch_light 作者: ne7ermore 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def forward(self, q, k, v, attn_mask):
        """Multi-head attention: project, attend per head, merge, residual + norm.

        Args:
            q, k, v: [bsz, len_*, d_model] tensors. k and v must share
                ``d_model`` with q, because the same reshape is applied to all
                three.
            attn_mask: mask for a single head. It is repeated ``n_head`` times
                along dim 0 before being handed to ``self.attention``.

        Returns:
            ``self.lm(projected + q)``, viewed as [bsz, len_q, -1] before the
            residual add (assumes ``self.w_o`` maps back to d_model; TODO
            confirm).
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        residual = q

        bsz, len_q, d_model = q.size()
        len_k, len_v = k.size(1), v.size(1)

        def reshape(x):
            """[bsz, len, d_model] -> [n_head, bsz*len, d_model]"""
            return x.repeat(n_head, 1, 1).view(n_head, -1, d_model)

        q_s, k_s, v_s = map(reshape, [q, k, v])

        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)
        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)

        outputs = self.attention(q_s, k_s, v_s, attn_mask.repeat(n_head, 1, 1))
        # Heads are stacked on dim 0 in blocks of bsz; move them onto the feature dim.
        outputs = torch.cat(torch.split(outputs, bsz, dim=0), dim=-1).view(-1, n_head*d_v)
        # Tie dropout to the module's train/eval mode so eval() disables it.
        outputs = F.dropout(self.w_o(outputs), p=self.dropout,
                            training=self.training).view(bsz, len_q, -1)
        return self.lm(outputs + residual)
fconv.py 文件源码 项目:ParlAI 作者: facebookresearch 项目源码 文件源码 阅读 39 收藏 0 点赞 0 评论 0
def forward(self, x, target_embedding, encoder_out):
        """One attention step of a convolutional decoder.

        Args:
            x: decoder states; also kept as the residual.
            target_embedding: added to the projected states before scoring.
            encoder_out: pair whose element 0 is multiplied with the query to
                get scores, and whose element 1 is attended over to build the
                context.

        Returns:
            (output, attn_scores); attn_scores are softmax-normalised over the
            last dimension.
        """
        half = math.sqrt(0.5)
        residual = x

        query = (self.in_projection(x) + target_embedding) * half
        scores = self.bmm(query, encoder_out[0])

        # Softmax over the last dim by flattening the two leading dims.
        shape = scores.size()
        attn_scores = F.softmax(scores.view(shape[0] * shape[1], shape[2])).view(shape)

        context = self.bmm(attn_scores, encoder_out[1])

        # Rescale the context by s * sqrt(1/s), with s = encoder_out[1].size(1).
        src_len = encoder_out[1].size(1)
        context = context * (src_len * math.sqrt(1.0 / src_len))

        out = (self.out_projection(context) + residual) * half
        return out, attn_scores
Modules.py 文件源码 项目:attention-is-all-you-need-pytorch 作者: jadore801120 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def forward(self, q, k, v, attn_mask=None):
        """Scaled dot-product attention.

        Scores are ``bmm(q, k^T) / self.temper``. Entries where ``attn_mask``
        is set are filled with -inf before ``self.softmax`` and
        ``self.dropout`` are applied.

        Returns:
            (bmm(attn, v), attn)

        Raises:
            AssertionError: if ``attn_mask`` is given with a shape different
                from the score tensor.
        """
        logits = torch.bmm(q, k.transpose(1, 2)) / self.temper

        if attn_mask is not None:
            assert attn_mask.size() == logits.size(), \
                    'Attention mask shape {} mismatch ' \
                    'with Attention logit tensor shape ' \
                    '{}.'.format(attn_mask.size(), logits.size())
            logits.data.masked_fill_(attn_mask, -float('inf'))

        weights = self.dropout(self.softmax(logits))
        return torch.bmm(weights, v), weights
GlobalAttention.py 文件源码 项目:bandit-nmt 作者: khanhptnk 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def forward(self, inputs, context):
        """Global attention of a decoder state over source positions.

        inputs: batch x dim
        context: batch x sourceL x dim

        Returns (tanh(linear_out([weighted_context, inputs])), attn), where attn
        is batch x sourceL. Masked positions (``self.mask``) are filled with
        ``-_INF``, a module-level constant not visible in this snippet.
        """
        query = self.linear_in(inputs).unsqueeze(2)  # batch x dim x 1

        scores = torch.bmm(context, query).squeeze(2)  # batch x sourceL
        if self.mask is not None:
            scores.data.masked_fill_(self.mask, -_INF)
        attn = self.sm(scores)

        # (batch x 1 x sourceL) @ (batch x sourceL x dim) -> batch x dim
        row_weights = attn.view(attn.size(0), 1, attn.size(1))
        weighted = torch.bmm(row_weights, context).squeeze(1)
        combined = torch.cat((weighted, inputs), 1)

        return self.tanh(self.linear_out(combined)), attn
fconv.py 文件源码 项目:fairseq-py 作者: facebookresearch 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def forward(self, x, target_embedding, encoder_out):
        """Attention step used by the convolutional decoder.

        The query is ``(in_projection(x) + target_embedding) * sqrt(0.5)``.
        It is scored against ``encoder_out[0]``, and the softmax weights
        average ``encoder_out[1]``. The result is rescaled, projected back and
        combined with the residual ``x`` (also scaled by sqrt(0.5)).

        Returns:
            (output, attn_scores) where attn_scores sum to 1 over the last dim.
        """
        residual = x
        scale = math.sqrt(0.5)

        x = self.bmm((self.in_projection(x) + target_embedding) * scale, encoder_out[0])

        # Flatten leading dims so the softmax runs over the last one.
        b, t, s_dim = x.size()
        attn_scores = F.softmax(x.view(b * t, s_dim)).view(x.size())

        x = self.bmm(attn_scores, encoder_out[1])

        # Multiply by s * sqrt(1/s) where s is encoder_out[1].size(1).
        s = encoder_out[1].size(1)
        x = x * (s * math.sqrt(1.0 / s))

        return (self.out_projection(x) + residual) * scale, attn_scores
GlobalAttention.py 文件源码 项目:NeuralMT 作者: hlt-mt 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def forward(self, input, context):
        """Global attention over the source context.

        input: batch x dim
        context: batch x sourceL x dim

        Returns (tanh(linear_out([weighted_context, input])), attn) with attn of
        shape batch x sourceL. Positions selected by ``self.mask`` get -inf
        scores.
        """
        projected = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

        logits = torch.bmm(context, projected).squeeze(2)  # batch x sourceL
        if self.mask is not None:
            logits.data.masked_fill_(self.mask, -float('inf'))
        attn = self.sm(logits)

        batch, source_len = attn.size(0), attn.size(1)
        summary = torch.bmm(attn.view(batch, 1, source_len), context).squeeze(1)  # batch x dim

        output = self.linear_out(torch.cat((summary, input), 1))
        return self.tanh(output), attn
GlobalAttention.py 文件源码 项目:alpha-dimt-icmlws 作者: sotetsuk 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def forward(self, input, context):
        """Attend from a target state over every source position.

        input: batch x dim
        context: batch x sourceL x dim

        Returns (context_output, attn): ``tanh(linear_out([weighted, input]))``
        and the batch x sourceL softmax weights (-inf where ``self.mask`` is
        set).
        """
        target_t = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

        attn = torch.bmm(context, target_t).squeeze(2)  # batch x sourceL
        if self.mask is not None:
            attn.data.masked_fill_(self.mask, -float('inf'))
        attn = self.sm(attn)

        as_row = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL
        weighted_context = torch.bmm(as_row, context).squeeze(1)  # batch x dim

        return self.tanh(self.linear_out(torch.cat((weighted_context, input), 1))), attn
pointnet.py 文件源码 项目:pointnet2.pytorch 作者: eriche2016 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def forward(self, x):
        """PointNet feature extractor with a learned input transform.

        ``self.stn(x)`` predicts a per-sample matrix. The points are
        right-multiplied by it (assumes x is [B, 3, N] with trans [B, 3, 3];
        TODO confirm), then passed through the conv/bn stages and ``self.mp1``.

        Returns:
            ``(global_feature, trans)`` when ``self.global_feat`` is set.
            Otherwise ``(cat([global tiled over self.num_points, pointfeat], 1),
            trans)``, where pointfeat is the output of the first conv stage.
        """
        trans = self.stn(x)  # transform regressed by the STN
        # [B, C, N] -> [B, N, C], apply trans, then back to [B, C, N].
        aligned = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)

        pointfeat = F.relu(self.bn1(self.conv1(aligned)))
        feat = F.relu(self.bn2(self.conv2(pointfeat)))
        feat = self.mp1(self.bn3(self.conv3(feat)))
        feat = feat.view(-1, 1024)

        if self.global_feat:  # global feature only (e.g. for classification)
            return feat, trans

        tiled = feat.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([tiled, pointfeat], 1), trans
blas.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def backward(ctx, grad_output):
        """Backward of the batched add-matmul-reduce Function (ctx style).

        ``grad_output`` has no batch dim. It is broadcast as an expanded view
        to ``batch1.size(0) x batch1.size(1) x batch2.size(2)`` for the batch
        products.

        The additive input's gradient is scaled by ``ctx.alpha`` and the batch
        gradients by ``ctx.beta``. That implies a forward of the form
        ``alpha * add + beta * sum_i b1[i] @ b2[i]``; confirm against the
        forward.

        Returns:
            ``(grad_add_matrix, grad_batch1, grad_batch2, None, None, None)``;
            the trailing Nones correspond to the remaining forward arguments.
        """
        batch1, batch2 = ctx.saved_variables
        wants = ctx.needs_input_grad

        grad_add_matrix = None
        if wants[0]:
            grad_add_matrix = grad_output
            if ctx.alpha != 1:
                grad_add_matrix = grad_output.mul(ctx.alpha)

        grad_batch1 = grad_batch2 = None
        if any(wants[1:]):
            out_shape = (batch1.size(0), batch1.size(1), batch2.size(2))
            batch_grad_output = grad_output.unsqueeze(0).expand(*out_shape)
            if wants[1]:
                grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
            if wants[2]:
                grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
            if ctx.beta != 1:
                # Both products are fresh tensors, so scaling in place is safe.
                for grad in (grad_batch1, grad_batch2):
                    if grad is not None:
                        grad *= ctx.beta

        return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
blas.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def backward(ctx, grad_output):
        """Backward of the batched-add + batched-matmul Function (ctx style).

        ``grad_output`` already has the batch dimension and is used directly.
        The additive input's gradient is scaled by ``ctx.alpha`` and the batch
        gradients by ``ctx.beta``.

        Returns:
            ``(grad_add_batch, grad_batch1, grad_batch2, None, None, None)``.
        """
        batch1, batch2 = ctx.saved_variables
        flags = ctx.needs_input_grad

        grad_add_batch = None
        if flags[0]:
            grad_add_batch = grad_output if ctx.alpha == 1 else grad_output.mul(ctx.alpha)

        grad_batch1 = None
        if flags[1]:
            grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
            if ctx.beta != 1:
                grad_batch1.mul_(ctx.beta)

        grad_batch2 = None
        if flags[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
            if ctx.beta != 1:
                grad_batch2.mul_(ctx.beta)

        return grad_add_batch, grad_batch1, grad_batch2, None, None, None
MM.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def updateOutput(self, input):
        """Forward pass of the legacy MM module: (batched) product of a pair.

        Args:
            input: sequence ``(a, b)`` of two tensors, both 2-D or both 3-D.
                For 3-D inputs dim 0 is the batch dimension used by
                ``torch.bmm``. ``self.transA`` / ``self.transB`` transpose the
                (per-batch) matrices before multiplying.

        Returns:
            ``self.output``, resized to the product shape and written via ``out=``.

        Raises:
            AssertionError: if ``input`` is not a pair of equal-rank 2-D or
                3-D tensors.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            torch.mm(a, b, out=self.output)
        else:
            # A 3-D tensor only has dims 0..2: transpose the matrix dims (1, 2)
            # and keep the batch dim 0. (2, 3) would be out of range.
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)

        return self.output
MV.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def updateOutput(self, input):
        """Matrix-vector product, optionally batched, written into ``self.output``.

        ``input`` is ``(M, v)``: a 2-D ``M`` with a 1-D ``v``, or a 3-D ``M``
        (batch first) with a 2-D ``v``. ``self.trans`` transposes the
        (per-batch) matrix first. Returns ``self.output``.

        Raises:
            AssertionError: for any other combination of ranks.
        """
        M, v = input
        assert M.ndimension() == 2 or M.ndimension() == 3

        batched = M.ndimension() == 3
        if not batched:
            assert v.ndimension() == 1
            if self.trans:
                M = M.transpose(0, 1)
            self.output.resize_(M.size(0))
            torch.mv(M, v, out=self.output)
            return self.output

        assert v.ndimension() == 2
        if self.trans:
            M = M.transpose(1, 2)
        n_batch, n_rows = M.size(0), M.size(1)
        self.output.resize_(n_batch, n_rows, 1)
        # Each vector becomes a column; the singleton dim is dropped afterwards.
        column = v.view(v.size(0), v.size(1), 1)
        torch.bmm(M, column, out=self.output).resize_(n_batch, n_rows)
        return self.output
test_torch.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def _test_btrisolve(self, cast):
        """btrisolve on a batched LU factorisation must reproduce b = A x.

        ``cast`` moves the fixtures to the tensor type under test.
        """
        lhs = torch.FloatTensor((((1.3722, -0.9020),
                                  (1.8849, 1.9169)),
                                 ((0.7187, -1.1695),
                                  (-0.0139, 1.3572)),
                                 ((-1.6181, 0.7148),
                                  (1.3728, 0.1319))))
        rhs = torch.FloatTensor(((4.02, 6.19),
                                 (-1.56, 4.00),
                                 (9.81, -4.09)))
        lhs = cast(lhs)
        rhs = cast(rhs)
        info = cast(torch.IntTensor())
        LU_data, pivots = lhs.btrifact(info=info)
        # The test requires info to be all zeros (presumably a non-zero entry
        # flags a failed factorisation; confirm against btrifact's docs).
        self.assertEqual(info.abs().sum(), 0)
        solution = torch.btrisolve(rhs, LU_data, pivots)
        # Multiply back: [3, 2, 2] @ [3, 2, 1] -> [3, 2].
        reconstructed = torch.bmm(lhs, solution.unsqueeze(2)).squeeze(2)
        self.assertEqual(reconstructed, rhs)
gridgen.py 文件源码 项目:lr-gan.pytorch 作者: jwyang 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def forward(self, input1):
        """Map a fixed base grid through per-sample affine parameters.

        ``self.grid`` is copied once per sample into ``self.batchgrid``, which
        is kept on the module (presumably for use in backward; TODO confirm).
        The batch grid is viewed as [N, height*width, 3] and right-multiplied
        by ``input1`` with its last two dims transposed.

        Args:
            input1: per-sample parameters. The final ``view(..., 2)`` needs the
                product's last dim to be 2, i.e. ``input1`` is [N, 2, 3].

        Returns:
            Tensor of shape [N, height, width, 2].
        """
        self.input1 = input1
        self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
        for i in range(input1.size(0)):
            self.batchgrid[i] = self.grid

        if input1.is_cuda:
            self.batchgrid = self.batchgrid.cuda()

        batchgrid_temp = self.batchgrid.view(-1, self.height*self.width, 3)
        # bmm accepts a strided (transposed) operand directly, so no explicit
        # .contiguous() copy is needed.
        input_temp = torch.transpose(input1, 1, 2)
        output_temp = torch.bmm(batchgrid_temp, input_temp)
        # bmm returns a new contiguous tensor, so view() is valid here.
        return output_temp.view(-1, self.height, self.width, 2)
model.py 文件源码 项目:Seq2Seq-PyTorch 作者: MaximumEntropy 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def forward(self, input, context):
        """Propagate input through the attention layer.

        input: batch x dim
        context: batch x sourceL x dim

        Returns (h_tilde, attn): ``tanh(linear_out([weighted_context, input]))``
        and the batch x sourceL attention weights.
        """
        query = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
        attn = self.sm(torch.bmm(context, query).squeeze(2))  # batch x sourceL

        weights = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL
        weighted_context = torch.bmm(weights, context).squeeze(1)  # batch x dim

        h_tilde = self.tanh(self.linear_out(torch.cat((weighted_context, input), 1)))
        return h_tilde, attn
blas.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def backward(ctx, grad_output):
        """Backward of the batched add-matmul-reduce Function (ctx style).

        ``grad_output`` lacks a batch dim. It is broadcast (expanded view) to
        ``batch1.size(0) x batch1.size(1) x batch2.size(2)`` before the batch
        products. ``ctx.alpha`` scales the additive input's gradient and
        ``ctx.beta`` scales the batch gradients.

        Returns:
            ``(grad_add_matrix, grad_batch1, grad_batch2, None, None, None)``.
        """
        batch1, batch2 = ctx.saved_variables
        need_add = ctx.needs_input_grad[0]
        need_b1 = ctx.needs_input_grad[1]
        need_b2 = ctx.needs_input_grad[2]

        grad_add_matrix = None
        if need_add:
            grad_add_matrix = grad_output.mul(ctx.alpha) if ctx.alpha != 1 else grad_output

        if any(ctx.needs_input_grad[1:]):
            tiled = grad_output.unsqueeze(0).expand(
                batch1.size(0), batch1.size(1), batch2.size(2))

        grad_batch1 = None
        if need_b1:
            grad_batch1 = torch.bmm(tiled, batch2.transpose(1, 2))
            if ctx.beta != 1:
                grad_batch1 *= ctx.beta

        grad_batch2 = None
        if need_b2:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), tiled)
            if ctx.beta != 1:
                grad_batch2 *= ctx.beta

        return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
blas.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def backward(ctx, grad_output):
        """Backward of the batched-add + batched-matmul Function (ctx style).

        ``grad_output`` already has the batch dim. Gradients for batch1 and
        batch2 are ``grad_output @ batch2^T`` and ``batch1^T @ grad_output``,
        scaled by ``ctx.beta``. The additive input's gradient is scaled by
        ``ctx.alpha``.

        Returns:
            ``(grad_add_batch, grad_batch1, grad_batch2, None, None, None)``.
        """
        batch1, batch2 = ctx.saved_variables
        needs = ctx.needs_input_grad
        grad_add_batch = grad_batch1 = grad_batch2 = None

        if needs[0]:
            grad_add_batch = grad_output
            if ctx.alpha != 1:
                grad_add_batch = grad_output.mul(ctx.alpha)

        if needs[1]:
            grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
        if needs[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)

        if ctx.beta != 1:
            # Fresh bmm results, so in-place scaling is safe.
            if grad_batch1 is not None:
                grad_batch1.mul_(ctx.beta)
            if grad_batch2 is not None:
                grad_batch2.mul_(ctx.beta)

        return grad_add_batch, grad_batch1, grad_batch2, None, None, None
MM.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def updateOutput(self, input):
        """Legacy MM forward: matrix product of a pair, batched when 3-D.

        ``input`` is ``(a, b)`` with equal ranks (2 or 3). ``self.transA`` and
        ``self.transB`` transpose the (per-batch) matrices first. The result is
        written into ``self.output``, which is returned.

        Raises:
            AssertionError: for a non-pair input or unsupported/mismatched ranks.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        batched = a.ndimension() == 3
        if batched:
            # Swap the two matrix dims; dim 0 is the batch.
            a = a.transpose(1, 2) if self.transA else a
            b = b.transpose(1, 2) if self.transB else b
            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)
            return self.output

        a = a.t() if self.transA else a
        b = b.t() if self.transB else b
        self.output.resize_(a.size(0), b.size(1))
        torch.mm(a, b, out=self.output)
        return self.output
MV.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def updateOutput(self, input):
        """Legacy MV forward: (batched) matrix-vector product into ``self.output``.

        Accepts ``(M, v)`` as either 2-D ``M`` with 1-D ``v``, or 3-D ``M``
        (batch first) with 2-D ``v``. When ``self.trans`` is set the
        (per-batch) matrix is transposed first. Returns ``self.output``.

        Raises:
            AssertionError: for any other combination of ranks.
        """
        M, v = input
        assert M.ndimension() == 2 or M.ndimension() == 3

        if M.ndimension() == 3:
            assert v.ndimension() == 2
            matrix = M.transpose(1, 2) if self.trans else M
            self.output.resize_(matrix.size(0), matrix.size(1), 1)
            vec_as_col = v.view(v.size(0), v.size(1), 1)
            result = torch.bmm(matrix, vec_as_col, out=self.output)
            # Drop the trailing singleton dim introduced by the column view.
            result.resize_(matrix.size(0), matrix.size(1))
        else:
            assert v.ndimension() == 1
            matrix = M.transpose(0, 1) if self.trans else M
            self.output.resize_(matrix.size(0))
            torch.mv(matrix, v, out=self.output)

        return self.output
test_torch.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _test_btrisolve(self, cast):
        """Solve a batch of 2x2 systems via btrifact/btrisolve and verify A x == b.

        ``cast`` converts each fixture to the tensor type being tested.
        """
        matrices = cast(torch.FloatTensor((((1.3722, -0.9020),
                                             (1.8849, 1.9169)),
                                            ((0.7187, -1.1695),
                                             (-0.0139, 1.3572)),
                                            ((-1.6181, 0.7148),
                                             (1.3728, 0.1319)))))
        targets = cast(torch.FloatTensor(((4.02, 6.19),
                                          (-1.56, 4.00),
                                          (9.81, -4.09))))
        info = cast(torch.IntTensor())
        LU_data, pivots = matrices.btrifact(info=info)
        # Expected to stay all-zero (presumably non-zero marks a failed
        # factorisation; verify against btrifact's documentation).
        self.assertEqual(info.abs().sum(), 0)
        x = torch.btrisolve(targets, LU_data, pivots)
        self.assertEqual(torch.bmm(matrices, x.unsqueeze(2)).squeeze(2), targets)
pointnet.py 文件源码 项目:pointnet.pytorch 作者: fxia22 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def forward(self, x):
        """PointNet feature network with an STN-predicted input transform.

        The per-sample matrix from ``self.stn`` is applied to the points
        (assumes x is [B, 3, N]; TODO confirm). The points then go through
        three conv/bn stages and ``self.mp1``.

        Returns:
            ``(global_feature [B, 1024], trans)`` if ``self.global_feat``.
            Otherwise the global feature repeated over ``self.num_points``,
            concatenated on dim 1 with the first-stage point features, plus
            ``trans``.
        """
        trans = self.stn(x)
        points = x.transpose(2, 1)
        points = torch.bmm(points, trans)
        points = points.transpose(2, 1)

        local_features = F.relu(self.bn1(self.conv1(points)))
        h = F.relu(self.bn2(self.conv2(local_features)))
        h = self.bn3(self.conv3(h))
        global_features = self.mp1(h).view(-1, 1024)

        if not self.global_feat:
            expanded = global_features.view(-1, 1024, 1).repeat(1, 1, self.num_points)
            return torch.cat([expanded, local_features], 1), trans
        return global_features, trans
attention.py 文件源码 项目:seq2seq.pytorch 作者: eladhoffer 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def calc_score(self, att_query, att_keys):
        """Compute attention scores between queries and keys.

        att_query is: b x t_q x n
        att_keys is b x t_k x n
        return b x t_q x t_k scores

        Supported ``self.mode`` values:
            'bahdanau': every query/key pair is summed, passed through tanh and
                scored by ``self.linear_att``.
            'dot_prod': batched dot product, divided in place by sqrt(n) when
                ``self.normalize`` is set.

        Raises:
            ValueError: for any other ``self.mode``, instead of failing later
                with an unbound ``out``.
        """
        b, t_k, n = list(att_keys.size())
        t_q = att_query.size(1)
        if self.mode == 'bahdanau':
            # Broadcast every query against every key: b x t_q x t_k x n.
            att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
            att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
            sum_qk = att_query + att_keys
            sum_qk = sum_qk.view(b * t_k * t_q, n)
            out = self.linear_att(F.tanh(sum_qk)).view(b, t_q, t_k)
        elif self.mode == 'dot_prod':
            out = torch.bmm(att_query, att_keys.transpose(1, 2))
            if self.normalize:
                out.div_(n ** 0.5)
        else:
            raise ValueError('unknown attention mode: {!r}'.format(self.mode))
        return out
VAE_HF.py 文件源码 项目:vae_vpflows 作者: jmtomczak 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def forward(self, v, z):
        '''
        Householder reflection of each z row about the hyperplane orthogonal to v.

        :param v: batch_size (B) x latent_size (L)
        :param z: batch_size (B) x latent_size (L)
        :return: z_new = z - 2 * v v_T z / ||v||^2, shape B x L
        '''
        # v * v_T : batch_dot( B x L x 1 * B x 1 x L ) = B x L x L
        vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))
        # v * v_T * z : batchdot( B x L x L * B x L x 1 ).squeeze(2) = B x L
        vvTz = torch.bmm(vvT, z.unsqueeze(2)).squeeze(2)
        # Squared norm of each row, kept as B x 1 so it lines up with the rows.
        # Without keepdim the sum is 1-D (size B). Expanding that to B x L then
        # fails when B != L, and when B == L it divides by the wrong row's norm.
        norm_sq = torch.sum(v * v, 1, keepdim=True)
        norm_sq = norm_sq.expand(norm_sq.size(0), v.size(1))  # B x L
        # calculate new z
        z_new = z - 2 * vvTz / norm_sq
        return z_new
VAE_ccLinIAF.py 文件源码 项目:vae_vpflows 作者: jmtomczak 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def forward(self, L, z):
        '''
        :param L: batch_size (B) x latent_size^2 (L^2)
        :param z: batch_size (B) x latent_size (L)
        :return: z_new = L*z
        '''
        # L->tril(L)
        L_matrix = L.view( -1, self.args.z1_size, self.args.z1_size ) # resize to get B x L x L
        LTmask = torch.tril( torch.ones(self.args.z1_size, self.args.z1_size), k=-1 ) # lower-triangular mask matrix (1s in lower triangular part)
        I = Variable( torch.eye(self.args.z1_size, self.args.z1_size).expand(L_matrix.size(0), self.args.z1_size, self.args.z1_size) )
        if self.args.cuda:
            LTmask = LTmask.cuda()
            I = I.cuda()
        LTmask = Variable(LTmask)
        LTmask = LTmask.unsqueeze(0).expand( L_matrix.size(0), self.args.z1_size, self.args.z1_size ) # 1 x L x L -> B x L x L
        LT = torch.mul( L_matrix, LTmask ) + I # here we get a batch of lower-triangular matrices with ones on diagonal

        # z_new = L * z
        z_new = torch.bmm( LT , z.unsqueeze(2) ).squeeze(2) # B x L x L * B x L x 1 -> B x L

        return z_new
vision.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def forward(ctx, theta, size):
        """Build an N x H x W x 2 sampling grid from per-sample affine matrices.

        ``size`` must be a ``torch.Size`` of the form (N, C, H, W). On CUDA the
        grid is produced by the cuDNN affine-grid kernel. On CPU a normalised
        base grid (x over [-1, 1] along W, y over [-1, 1] along H, and a
        constant 1) is multiplied by ``theta`` transposed. The final view to
        ``..., 2`` implies theta is N x 2 x 3. The base grid is stored on
        ``ctx`` for the CPU path.
        """
        assert type(size) == torch.Size
        N, C, H, W = size
        ctx.size = size
        ctx.is_cuda = theta.is_cuda

        if ctx.is_cuda:
            AffineGridGenerator._enforce_cudnn(theta)
            grid = theta.new(N, H, W, 2)
            theta = theta.contiguous()
            torch._C._cudnn_affine_grid_generator_forward(theta, grid, N, C, H, W)
            return grid

        # A single-sample axis collapses to the coordinate -1.
        xs = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
        ys = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])

        base_grid = theta.new(N, H, W, 3)
        base_grid[:, :, :, 0] = torch.ger(torch.ones(H), xs).expand_as(base_grid[:, :, :, 0])
        base_grid[:, :, :, 1] = torch.ger(ys, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
        base_grid[:, :, :, 2] = 1
        ctx.base_grid = base_grid

        flat = base_grid.view(N, H * W, 3)
        return torch.bmm(flat, theta.transpose(1, 2)).view(N, H, W, 2)


问题


面经


文章

微信
公众号

扫码关注公众号