linear_two.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:vampyre 作者: GAMPTeam 项目源码 文件源码
def get_tgt_vec(self,r):
        """
        Computes the target vector `g` in the above description
        """
        r0,r1 = r
        g0 = 1/self.rsqrt0*r0
        if self.wvar_pos:
            gout = 1/self.wsqrt*np.broadcast_to(self.b,self.shape1)
            g1 = 1/self.rsqrt1*r1            
            g = np.hstack((gout.ravel(),g0.ravel(),g1.ravel()))
        else:
            g1 = 1/self.rsqrt1*(r1-self.b)
            g = np.hstack((g0.ravel(),g1.ravel()))
        return g
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号