rntn.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:RNTN 作者: munikarmanish 项目源码 文件源码
def forward_prop(self, tree):
        """Run the forward pass over ``tree`` and score every node.

        Vectors are computed bottom-up (children before parent). Each node
        gets a softmax class distribution. The cross-entropy cost and a
        confusion matrix are accumulated over the whole subtree.

        Leaves take their vector from column ``self.word_map[word]`` of the
        embedding matrix ``self.L``. Words missing from ``word_map`` fall
        back to the ``tr.UNK`` entry. Internal nodes concatenate their two
        children's vectors into ``lr`` and compute, for each output k::

            vector[k] = tanh(lr^T V[k] lr + (W lr)[k] + b[k])

        Side effects: sets ``tree.vector``, ``tree.output`` and
        ``tree.fprop`` on every node of the subtree.

        Args:
            tree: a node supporting ``tree[0]`` / ``tree[1]`` access (the
                two children, or the word at a leaf) and ``tree.label()``.
                Leaf-ness is decided by ``tr.isleaf``.

        Returns:
            (cost, result):
                cost: the summed negative log-likelihood of the gold labels.
                result: an ``(n_classes, n_classes)`` confusion matrix
                    indexed ``[true_label, predicted_label]``, where
                    ``n_classes`` is ``self.Ws.shape[0]``.
        """
        # Labels and predictions both index the softmax output, whose length
        # is the number of rows of Ws.
        n_classes = self.Ws.shape[0]
        cost = 0.0
        result = np.zeros((n_classes, n_classes))

        if tr.isleaf(tree):
            # The leaf output is the word's embedding column; unknown words
            # use the UNK entry. Only a missing key counts as "unknown", so
            # unrelated errors (and KeyboardInterrupt) are not swallowed.
            try:
                index = self.word_map[tree[0]]
            except KeyError:
                index = self.word_map[tr.UNK]
            tree.vector = self.L[:, index]
        else:
            # Children first, accumulating their cost and confusion counts.
            lcost, lresult = self.forward_prop(tree[0])
            rcost, rresult = self.forward_prop(tree[1])
            cost += lcost + rcost
            result += lresult + rresult

            # Bilinear tensor term plus affine term, squashed by tanh.
            lr = np.hstack([tree[0].vector, tree[1].vector])
            tree.vector = np.tanh(
                np.tensordot(self.V, np.outer(lr, lr), axes=([1, 2], [0, 1])) +
                np.dot(self.W, lr) + self.b)

        # Numerically stable softmax: subtracting the max does not change
        # the result but prevents overflow in exp().
        output = np.dot(self.Ws, tree.vector) + self.bs
        output = np.exp(output - np.max(output))
        tree.output = output / np.sum(output)

        # Mark every node (not just leaves) as forward-propagated; this was
        # previously misspelled ``tree.frop``.
        tree.fprop = True

        # Cross-entropy cost and confusion-matrix entry for this node.
        true_label = int(tree.label())
        predicted_label = np.argmax(tree.output)
        cost -= np.log(tree.output[true_label])
        result[true_label, predicted_label] += 1

        return cost, result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号