model.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:teras 作者: chantera 项目源码 文件源码
def forward(self, input1, input2):
        """Score every (input1 position, input2 position) pair biaffinely.

        Args:
            input1: 3-D tensor unpacked as ``(batch_size, len1, dim1)``.
            input2: 3-D tensor unpacked as ``(batch_size, len2, dim2)``.
                Its batch size is assumed to equal ``input1``'s (TODO confirm
                callers guarantee this; it is not checked here).

        Returns:
            Tensor of shape ``(batch_size, len1, len2, self.out_features)``
            with ``out[b, i, j, o] = sum_{p,q} x1[b, i, p] * W[p, q, o] *
            x2[b, j, q]`` (plus ``self.bias[o]`` when ``self._use_bias[2]``).
            ``x1``/``x2`` are the inputs, each optionally extended with a
            constant-1 feature when ``self._use_bias[0]`` / ``[1]`` is set.
            ``self.weight`` is assumed to be laid out as
            ``(dim1, dim2, out_features)`` after that extension; the reshape
            below relies on that layout.
        """
        out_size = self.out_features
        batch_size, len1, dim1 = input1.size()
        len2, dim2 = input2.size()[1:]

        def _append_ones(x, length):
            # Build the constant-1 column from ``x`` itself so it inherits
            # x's dtype and device. torch.ones() always gives a default-dtype
            # CPU tensor, which breaks under model.half()/double() and needs a
            # manual .cuda() hop to the right GPU.
            ones = x.data.new(batch_size, length, 1).fill_(1)
            return torch.cat((x, Variable(ones)), dim=2)

        if self._use_bias[0]:
            input1 = _append_ones(input1, len1)
            dim1 += 1
        if self._use_bias[1]:
            input2 = _append_ones(input2, len2)
            dim2 += 1

        # Apply W to every input1 position in one matmul:
        # (batch*len1, dim1) @ (dim1, out*dim2). transpose(1, 2) orders the
        # weight's columns as (out, dim2) so the result splits per channel.
        input1_reshaped = input1.contiguous().view(batch_size * len1, dim1)
        W_reshaped = torch.transpose(self.weight, 1, 2) \
            .contiguous().view(dim1, out_size * dim2)
        affine = torch.mm(input1_reshaped, W_reshaped) \
            .view(batch_size, len1 * out_size, dim2)
        # Contract dim2 against every input2 position:
        # (batch, len1*out, dim2) x (batch, dim2, len2) -> (batch, len1*out, len2),
        # then reorder to (batch, len1, len2, out).
        biaffine = torch.transpose(
            torch.bmm(affine, torch.transpose(input2, 1, 2))
            .view(batch_size, len1, out_size, len2), 2, 3)
        if self._use_bias[2]:
            # Out-of-place add: biaffine is a transposed view of the bmm
            # output, and avoiding an in-place update on a view keeps
            # autograd bookkeeping simple.
            biaffine = biaffine + self.bias.expand_as(biaffine)
        return biaffine
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号