def set_sig(self, X: "torch.Tensor", Y: "torch.Tensor") -> None:
    """Set ``self.sig`` to the root-mean-square residual of the model on (X, Y).

    The prediction is ``self.lin(X) + self.net(X)``. The squared residual
    against ``Y`` is averaged over dimension 0 and square-rooted, and the
    result replaces ``self.sig.data`` in place (nothing is returned).

    NOTE(review): presumably dimension 0 is the sample/batch axis and
    ``sig`` is a per-output noise scale -- confirm against the callers.

    Args:
        X: Input passed to both ``self.lin`` and ``self.net``.
        Y: Target values; must be broadcast-compatible with the prediction.
    """
    # This is a statistics update, not a training step: skip autograd
    # bookkeeping so no graph is built for the forward pass.
    with torch.no_grad():
        Y_pred = self.lin(X) + self.net(X)
        mean_sq_err = torch.mean((Y_pred - Y) ** 2, 0)
        # The result already lives on the device the model computed on, so
        # no explicit transfer is needed. The previous unconditional
        # ``.cuda()`` raised on machines without CUDA.
        self.sig.data = torch.sqrt(mean_sq_err)