def _modReLU(self, h, bias):
"""
sign(z)*relu(z)
"""
batch_size = h.size(0)
sign = torch.sign(h)
bias_batch = (bias.unsqueeze(0)
.expand(batch_size, *bias.size()))
return sign * functional.relu(torch.abs(h) + bias_batch)
# NOTE(review): removed stray blog-page navigation text ("评论列表" = comment
# list, "文章目录" = article table of contents) left over from a copy-paste;
# it was not code and broke Python parsing.