def reset_parameters(self):
    """Initialise the mean and noise-scale parameters of the layer.

    ``weight`` and ``bias`` are drawn uniformly from ``[-b, b]`` with
    ``b = sqrt(3 / in_features)``.  ``sigma_weight`` and ``sigma_bias`` are
    filled with the constant ``self.sigma_init``.

    This is a no-op until ``sigma_weight`` exists.  The method is reached
    from ``super().__init__()`` before the sigma parameters are registered,
    and initialising them at that point would fail.
    """
    # Only init after all params added (otherwise super().__init__() fails)
    if not hasattr(self, 'sigma_weight'):
        return
    bound = math.sqrt(3 / self.in_features)
    # Use the in-place initialisers (trailing underscore).  The un-suffixed
    # init.uniform / init.constant are deprecated aliases that emit a
    # FutureWarning on every call.
    init.uniform_(self.weight, -bound, bound)
    init.uniform_(self.bias, -bound, bound)
    init.constant_(self.sigma_weight, self.sigma_init)
    init.constant_(self.sigma_bias, self.sigma_init)
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")