def forward(self, data):
    """Score all (child, parent) pairs and all root candidates for ``data``.

    Each item of ``data`` is indexed at ``[0]`` and ``[1]``, and those two
    values are fed to three embedding layers (``p_embed``, ``c_embed``,
    ``r_embed``). Every embedding sequence is then run through its matching
    encoder (``p_encode``, ``c_encode``, ``r_encode``).

    Args:
        data: Sized, re-iterable collection of items, each supporting
            ``item[0]`` and ``item[1]``. It is iterated three times.

    Returns:
        ``(parent_scores, root_scores)``:
        - ``parent_scores`` has shape ``(len(data), len(data))``. It is the
          ``c``-side matrix multiplied by the transpose of the ``p``-side
          matrix, so entry ``[i, j]`` is row ``i`` of the former dotted with
          row ``j`` of the latter.
        - ``root_scores`` is ``self.r_scorer`` applied to the concatenated
          ``r``-side encodings, reshaped to ``(1, len(data))``.
    """
    n_items = len(data)
    batched_shape = (1, n_items, self.hidden_size)

    # Embed every item for each role before running any encoder. This is the
    # same call order as before, which matters if any layer is stochastic
    # (e.g. dropout).
    parent_emb = [self.p_embed(item[0], item[1]) for item in data]
    # "c" presumably means "child" (pairs with parent_scores below); confirm.
    child_emb = [self.c_embed(item[0], item[1]) for item in data]
    root_emb = [self.r_embed(item[0], item[1]) for item in data]

    parent_enc = self.p_encode(parent_emb)
    child_enc = self.c_encode(child_emb)
    root_enc = self.r_encode(root_emb)

    # Stack the per-item encodings into one (1, n_items, hidden_size) tensor.
    # NOTE(review): this assumes each encoder yields exactly hidden_size
    # values per item, in item order. The reshape only checks the total
    # element count, so confirm the encoder output shapes.
    parents = functions.reshape(functions.concat(parent_enc, 0), batched_shape)
    children = functions.reshape(functions.concat(child_enc, 0), batched_shape)
    roots = functions.concat(root_enc, 0)

    # children @ parents^T gives shape (1, n_items, n_items); drop the batch axis.
    # NOTE(review): newer Chainer releases deprecate batch_matmul in favour
    # of functions.matmul. Check which Chainer version this project targets.
    pair_scores = functions.batch_matmul(children, parents, transb=True)
    parent_scores = functions.reshape(pair_scores, (n_items, n_items))

    # r_scorer's output must contain n_items elements for this reshape to work.
    root_scores = functions.reshape(self.r_scorer(roots), (1, n_items))
    return parent_scores, root_scores
评论列表
文章目录