def forward(self, _input):
    """Apply the linear transform using only the weights selected by ``self.mask``.

    The weight matrix is multiplied elementwise by the mask, and the result is
    passed to ``F.linear`` together with ``self.bias``. Masked-out weights
    therefore contribute nothing to the output.

    Args:
        _input: Input tensor forwarded unchanged to ``F.linear``.

    Returns:
        ``F.linear(_input, self.weight * mask, self.bias)``.
    """
    # NOTE(review): assumes ``self.mask`` is broadcastable to ``self.weight``
    # and lives on the same device -- TODO confirm against __init__.
    #
    # ``detach()`` replaces the deprecated ``torch.autograd.Variable(self.mask)``.
    # The legacy wrapper returned a detached tensor with requires_grad=False.
    # Detaching keeps that behaviour: gradients reach ``self.weight`` but never
    # flow into the mask itself.
    masked_weight = self.weight * self.mask.detach()
    return F.linear(_input, masked_weight, self.bias)
评论列表
文章目录