# Imports this snippet relies on (Chainer 1.x API, which still has `volatile`):
#   import sys
#   import numpy as np
#   import chainer
#   import chainer.functions as F
#   from collections import OrderedDict, defaultdict
def solve(self, docD, train=True):
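    """Encode one document, then answer each of its queries.

    Returns the accumulated cross-entropy loss over the document's queries,
    the count of correctly answered queries, and the count answered correctly
    after masking candidates that appear in the query (only attempted when
    the unmasked prediction was wrong).
    """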
    # old2newD: original entity id -> remapped id; e2sD: entity -> current vector
    old2newD, e2sD = self.initialize_entities(docD["entities"], self.args.max_ent_id, train=train)
    # e2dLD: entity -> list of dynamic representations, seeded with the initial state
    e2dLD = dict((e, [s]) for (e, s) in e2sD.items())
    sentences = self.reload_sentences(docD["sentences"], old2newD)
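    # Read the document sentence by sentence; whenever a sentence mentions known
    # entities, re-encode it and update each mentioned entity's representation.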
    for sent in sentences:
        i2sD = OrderedDict()       # token position -> entity's current vector
        e2iLD = defaultdict(list)  # entity -> its token positions in this sentence
        for i, token in enumerate(sent):
            if token in e2sD:
                i2sD[i] = e2sD[token]
                e2iLD[token].append(i)
        if not i2sD:  # skip sentences without any entities
            continue
        e2iLD = OrderedDict(e2iLD)
        concat_h_L = self.encode_context(sent, i2sD, e2iLD, train=train)
        for e, concat_h in zip(e2iLD.keys(), concat_h_L):
            # add a new dynamic representation for this mention,
            e2dLD[e].append(F.tanh(self.W_hd(concat_h)))
            # then refresh the entity vector by max-pooling old and new states
            e2sD[e] = F.max(F.concat([e2sD[e], e2dLD[e][-1]], axis=0), axis=0, keepdims=True)
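
    # Document read; now answer its queries. Each candidate entity is scored by
    # combining the query encoding with query-conditioned attention over the
    # entity's mention history.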
    EPS = sys.float_info.epsilon
    accum_loss_doc, TorFs, subTorFs = 0, 0, 0
    for query, answer in zip(docD["queries"], docD["answers"]):
        query = self.reload_sentence(query, old2newD)
        answer = old2newD[int(answer)]
        i2sD = dict((i, e2sD[token]) for i, token in enumerate(query) if token in e2sD)
        u_Dq, q = self.encode_query(query, i2sD, train=train)
        eL, sL = zip(*e2sD.items())
        # query-conditioned attention over each entity's history of dynamic representations
        pre_vL = [self.attention_history(e2dLD[e], q, train=train) for e in eL]
        v_eDq = self.W_dv(F.concat(pre_vL, axis=0))
        answer_idx = eL.index(answer)
        p = self.predict_answer(u_Dq, v_eDq, [token in query for token in eL], train=train) + EPS
        t = chainer.Variable(self.xp.array([answer_idx]).astype(np.int32), volatile=not train)
        accum_loss_doc += F.softmax_cross_entropy(p, t)
        p_data = p.data[0, :]
        max_idx = self.xp.argmax(p_data)
        TorFs += (max_idx == answer_idx)
        if max_idx != answer_idx:
            # secondary metric: mask candidates that occur in the query itself,
            # then check whether the gold entity now ranks first
            for sub_ans in [k for k, e in enumerate(eL) if e in query]:
                p_data[sub_ans] = -10000000
            subTorFs += (self.xp.argmax(p_data) == answer_idx)
    return accum_loss_doc, TorFs, subTorFs
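
# Usage sketch (not part of the original code): how `solve` might be driven in a
# Chainer 1.x training loop. `model` (an instance of the class defining `solve`)
# and `train_docs` (preprocessed document dicts with "entities", "sentences",
# "queries", and "answers" keys) are assumptions.
#
#   optimizer = chainer.optimizers.Adam()
#   optimizer.setup(model)
#   for docD in train_docs:
#       loss, n_correct, n_sub_correct = model.solve(docD, train=True)
#       optimizer.zero_grads()  # Chainer 1.x API, matching the `volatile` flag above
#       loss.backward()
#       optimizer.update()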