def reset_parameters(self) -> None:
    """Re-initialize the parameters of ``self.comp_linear``.

    The weight is filled with Kaiming (He) normal initialization using the
    library's default arguments, and the bias, when the layer has one, is
    set to zero. Both updates happen in place.
    """
    layer = self.comp_linear
    # The underscore-suffixed initializers are the supported in-place API;
    # they already run without autograd tracking, so the parameters are
    # passed directly rather than through ``.data``.
    init.kaiming_normal_(layer.weight)
    # A layer built without a bias exposes ``bias`` as None; nothing to do.
    if layer.bias is not None:
        init.constant_(layer.bias, val=0.0)