def f(self, net, dx, dg):
    """Return the row-wise gap between two Euclidean norms.

    Computes ``||net - dg|| - ||dx||``, with each ``tf.norm`` reduced over
    ``axis=1``. ``self`` is not read.

    Args:
        net: Tensor to compare against ``dg``. Presumably rank-2
            ``(batch, features)`` — TODO confirm against callers.
        dx: Tensor whose per-row norm is subtracted. Presumably the same
            shape as ``net`` — TODO confirm.
        dg: Tensor subtracted from ``net`` before taking the norm; must be
            broadcast-compatible with ``net``.

    Returns:
        A tensor holding ``tf.norm(net - dg, axis=1) - tf.norm(dx, axis=1)``.
    """
    # NOTE: this is currently not working that well; we might need a second
    # sample of X.
    residual_norm = tf.norm(net - dg, axis=1)  # norm of net - dg, reduced over axis 1
    dx_norm = tf.norm(dx, axis=1)  # norm of dx, reduced over axis 1
    return residual_norm - dx_norm