def get_tgt_vec(self, r):
    """
    Build the flattened target vector `g` from the input pair `r`.

    `r` is unpacked as `(r0, r1)`.  Each part is multiplied by the
    reciprocal of its scale attribute (`self.rsqrt0` for `r0`,
    `self.rsqrt1` for `r1`), flattened, and stacked into one vector.

    When `self.wvar_pos` is truthy, `self.b` is broadcast to
    `self.shape1`, scaled by `1/self.wsqrt`, and placed in front of the
    two scaled inputs.  Otherwise `self.b` is subtracted from `r1`
    before scaling and no leading block is added.

    Returns the result of `np.hstack` over the flattened pieces.
    """
    r0, r1 = r

    # The scaled first input is present in both layouts.
    pieces = [(1/self.rsqrt0)*r0]

    if not self.wvar_pos:
        # No separate block for b: fold it into the second input.
        pieces.append((1/self.rsqrt1)*(r1 - self.b))
    else:
        # b gets its own leading block, expanded to shape1 first.
        b_block = (1/self.wsqrt)*np.broadcast_to(self.b, self.shape1)
        pieces.insert(0, b_block)
        pieces.append((1/self.rsqrt1)*r1)

    return np.hstack([piece.ravel() for piece in pieces])
评论列表
文章目录