dern.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:der-network 作者: soskek 项目源码 文件源码
def solve(self, docD, train=True):
        """Read one document, update per-entity representations, and score its queries.

        Processing happens in two phases:

        1. Each sentence that mentions at least one entity key of the entity
           dict returned by ``self.initialize_entities`` is encoded with
           ``self.encode_context``.  For every entity mentioned in it, a new
           representation ``F.tanh(self.W_hd(h))`` is appended to that entity's
           history, and the entity's entry is replaced by the elementwise
           ``F.max`` of its old entry and that new representation.
        2. Each query is encoded with ``self.encode_query``.  Every entity is
           then scored with ``self.predict_answer``, and a softmax
           cross-entropy loss against the gold answer is accumulated.

        Args:
            docD: dict read through the keys ``"entities"``, ``"sentences"``,
                ``"queries"`` and ``"answers"``.  ``queries`` and ``answers``
                are iterated in lockstep.  Each answer is passed through
                ``int`` and looked up in the id remapping returned by
                ``initialize_entities``.  The remapped answer must be one of
                the entities, otherwise ``tuple.index`` raises ``ValueError``.
            train: forwarded to every sub-module call.  ``not train`` is also
                used as the ``volatile`` flag of the target variable.

        Returns:
            tuple ``(accum_loss_doc, TorFs, subTorFs)``:

            - ``accum_loss_doc`` is the summed loss over all queries, or the
              int ``0`` when there are none.
            - ``TorFs`` counts queries whose top-scoring entity is the answer.
            - ``subTorFs`` counts the *incorrect* queries for which the answer
              becomes top-scoring once every entity that also appears in the
              query text is masked out.
        """
        # Score written over masked entities before the second argmax; it only
        # needs to be far below any real score.
        masked_score = -10000000

        old2newD, e2sD = self.initialize_entities(docD["entities"], self.args.max_ent_id, train=train)
        # e2dLD[e] is the history for entity e: its initial entry from
        # initialize_entities, followed by one entry per sentence mentioning e.
        e2dLD = {e: [s] for e, s in e2sD.items()}
        sentences = self.reload_sentences(docD["sentences"], old2newD)

        for sent in sentences:
            i2sD = OrderedDict()   # token position -> current e2sD entry of the entity there
            e2iLD = OrderedDict()  # entity -> its positions, keys in first-mention order
            for i, token in enumerate(sent):
                if token in e2sD:
                    i2sD[i] = e2sD[token]
                    e2iLD.setdefault(token, []).append(i)
            if not i2sD:  # skip sentences without any entities
                continue

            # encode_context receives e2iLD, and its outputs are paired with
            # e2iLD's keys in iteration order.
            concat_h_L = self.encode_context(sent, i2sD, e2iLD, train=train)
            for e, concat_h in zip(e2iLD.keys(), concat_h_L):
                e2dLD[e].append(F.tanh(self.W_hd(concat_h)))
                # New entry = elementwise max over the old entry and the newest
                # representation (presumably both shaped (1, dim) -- TODO confirm).
                e2sD[e] = F.max(F.concat([e2sD[e], e2dLD[e][-1]], axis=0), axis=0, keepdims=True)

        EPS = sys.float_info.epsilon
        accum_loss_doc, TorFs, subTorFs = 0, 0, 0
        # e2sD is not modified below, so the candidate order is fixed for all
        # queries; computing it once avoids rebuilding it per query.
        eL = tuple(e2sD)

        for query, answer in zip(docD["queries"], docD["answers"]):
            query = self.reload_sentence(query, old2newD)
            answer = old2newD[int(answer)]
            i2sD = {i: e2sD[token] for i, token in enumerate(query) if token in e2sD}
            u_Dq, q = self.encode_query(query, i2sD, train=train)
            pre_vL = [self.attention_history(e2dLD[e], q, train=train) for e in eL]
            v_eDq = self.W_dv(F.concat(pre_vL, axis=0))
            answer_idx = eL.index(answer)

            # in_query[k] is True when candidate eL[k] occurs in the query text.
            # It is used both as the predict_answer mask and for the masking below.
            query_tokens = set(query)
            in_query = [e in query_tokens for e in eL]

            # NOTE(review): adding the same constant to every score changes
            # neither the softmax in the loss nor the argmax below; kept as-is.
            p = self.predict_answer(u_Dq, v_eDq, in_query, train=train) + EPS
            t = chainer.Variable(self.xp.array([answer_idx], dtype=np.int32), volatile=not train)
            accum_loss_doc += F.softmax_cross_entropy(p, t)

            # Copy: a basic slice is a view, and p.data is the input array held
            # by the loss node; the masking below must not write into it.
            p_data = p.data[0, :].copy()
            max_idx = self.xp.argmax(p_data)
            TorFs += (max_idx == answer_idx)
            if max_idx != answer_idx:
                for k, mentioned in enumerate(in_query):
                    if mentioned:
                        p_data[k] = masked_score
                subTorFs += (self.xp.argmax(p_data) == answer_idx)

        return accum_loss_doc, TorFs, subTorFs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号