MMTransE.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:MTransE 作者: muhaochen 项目源码 文件源码
def gradient_decent(self,tr,v_e1,v_e2,const_decay, L1=False):
        # f = || v_e1 tr - v_e2 ||_2^2

        assert len(v_e1.shape) == 1
        assert v_e1.shape == v_e2.shape
        assert tr.shape == (v_e1.shape[0], v_e2.shape[0])

        f_res = np.dot(v_e1, tr) - v_e2
        d_v_e1 = 2.0 * np.dot(tr, f_res)
        d_v_e2 = - 2.0 * f_res
        d_tr = 2.0 * np.dot(v_e1[:, np.newaxis], f_res[np.newaxis, :])

        v_e1 -= d_v_e1 * self.rate
        v_e2 -= d_v_e2 * self.rate
        tr -= d_tr * self.rate

        v_e1 /= LA.norm(v_e1)
        v_e2 /= LA.norm(v_e2)
        # don't touch tr

        return LA.norm(np.dot(v_e1, tr) - v_e2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号