rntn.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:RNTN 作者: munikarmanish 项目源码 文件源码
def back_prop(self, tree, error=None):
    """Backpropagate the classification error through ``tree``, accumulating gradients.

    Walks the tree recursively from this node down. At every node it adds the
    softmax-layer gradient to ``self.dWs`` / ``self.dbs``. It then pushes the
    error (plus any ``error`` handed down by the parent) through the node's
    activation. A leaf adds the result to its word's row of ``self.dL``. An
    internal node adds to ``self.dV``, ``self.dW`` and ``self.db``, then recurses
    into both children with their slice of the propagated error.

    Args:
        tree: Node of an already forward-propagated tree. Reads ``tree.output``
            (assumed to be the softmax class distribution — TODO confirm),
            ``tree.vector`` (the node's activation), ``tree.label()`` (gold
            class, cast with ``int``). For a leaf it reads ``tree[0]`` (the
            word); otherwise ``tree[0]`` / ``tree[1]`` (the two children).
        error: Gradient w.r.t. this node's vector coming from its parent,
            or ``None`` at the root.

    Returns:
        None. All gradients are accumulated in place on ``self``.
    """
    # Reset this node's forward-prop flag.
    # NOTE(review): this was previously written ``tree.frop`` (typo), so the
    # flag was never cleared; assumes the forward pass reads ``fprop`` — confirm.
    tree.fprop = False

    # Softmax gradient: predicted distribution minus the one-hot gold label.
    # Copy first so the in-place subtraction does not corrupt tree.output.
    deltas = np.array(tree.output, copy=True)
    deltas[int(tree.label())] -= 1.0
    self.dWs += np.outer(deltas, tree.vector)
    self.dbs += deltas

    # Error reaching this node's vector: from its own classifier plus its parent.
    deltas = np.dot(self.Ws.T, deltas)
    if error is not None:
        deltas += error
    # 1 - v**2 is the tanh derivative, so this assumes tree.vector = tanh(...)
    # in the forward pass — TODO confirm.
    deltas *= (1 - tree.vector**2)

    if tr.isleaf(tree):
        # Leaf: route the error to the word's embedding row; words missing
        # from word_map fall back to the UNK entry.
        try:
            index = self.word_map[tree[0]]
        except KeyError:
            index = self.word_map[tr.UNK]
        self.dL[index] += deltas
        return

    # Internal node: gradients of the composition parameters.
    # lr is the concatenation of the left and right child vectors.
    lr = np.hstack([tree[0].vector, tree[1].vector])
    outer = np.outer(deltas, lr)
    # Adds dV[k, i, j] = deltas[k] * lr[i] * lr[j].
    self.dV += (np.outer(lr, lr)[..., None] * deltas).T
    self.dW += outer
    self.db += deltas

    # Error for the children: through W, plus sum_k deltas[k] * (V[k] + V[k]^T) @ lr,
    # which is the gradient of a bilinear term lr^T V[k] lr (assumed forward form).
    deltas = np.dot(self.W.T, deltas)
    deltas += np.tensordot(self.V.transpose((0, 2, 1)) + self.V, outer.T,
                           axes=([1, 0], [0, 1]))

    # The first self.dim entries belong to the left child, the rest to the right.
    self.back_prop(tree[0], deltas[:self.dim])
    self.back_prop(tree[1], deltas[self.dim:])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号