def reset_params(self):
    """Write the stored initial parameter values back into both networks.

    Calls ``layers.set_all_param_values`` once for the treatment network
    (``self.treatment_output`` <- ``self.init_treatment_params``) and then
    once for the instrument network (``self.instrument_output`` <-
    ``self.init_instrument_params``), in that order.

    NOTE(review): the ``init_*_params`` attributes are presumably snapshots
    taken when the model was built (e.g. via ``get_all_param_values``); that
    happens outside this view — confirm against the constructor.
    """
    # (output layer, saved parameter values) pairs, restored in order:
    # treatment network first, instrument network second.
    restore_plan = (
        (self.treatment_output, self.init_treatment_params),
        (self.instrument_output, self.init_instrument_params),
    )
    for network_output, saved_values in restore_plan:
        layers.set_all_param_values(network_output, saved_values)