parse03.py source code

python

Project: nn_parsers  Author: odashi
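from chainer import functions

def forward(self, data):
    # Embed every element of `data` three times, once per role:
    # as a potential parent (p), as a child (c), and as a root candidate (r).
    ep_list = [self.p_embed(d[0], d[1]) for d in data]
    ec_list = [self.c_embed(d[0], d[1]) for d in data]
    er_list = [self.r_embed(d[0], d[1]) for d in data]
    # Run each embedding sequence through its role-specific encoder.
    p_list = self.p_encode(ep_list)
    c_list = self.c_encode(ec_list)
    r_list = self.r_encode(er_list)

    # Stack the N parent and child vectors and view each stack as a
    # batch of one (1, N, hidden_size) matrix so batch_matmul can use it.
    P = functions.reshape(
      functions.concat(p_list, 0),
      (1, len(data), self.hidden_size))
    C = functions.reshape(
      functions.concat(c_list, 0),
      (1, len(data), self.hidden_size))
    R = functions.concat(r_list, 0)

    # C x P^T gives an (N, N) matrix with parent_scores[i, j] = C[i] . P[j],
    # i.e. the score of word j being the parent of word i.
    # (batch_matmul is from older Chainer releases; later versions
    # deprecate it in favor of functions.matmul.)
    parent_scores = functions.reshape(
      functions.batch_matmul(C, P, transb=True),
      (len(data), len(data)))
    # One root score per word, laid out as a (1, N) row.
    root_scores = functions.reshape(
      self.r_scorer(R),
      (1, len(data)))

    return parent_scores, root_scores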
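The scoring step in forward() amounts to a dot product between every child vector and every parent vector, plus one root score per word. The pure-Python sketch below reproduces only that arithmetic so the layout of the two outputs is easy to check by hand. It assumes `r_scorer` behaves like a linear layer mapping `hidden_size` to 1, which the reshape to (1, N) suggests but the snippet does not confirm. The helper names `dot`, `parent_score_matrix`, `root_score_row` and the example weight and bias values are hypothetical and are not part of nn_parsers.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def parent_score_matrix(children: List[Vector], parents: List[Vector]) -> Matrix:
    """Mirrors reshape(batch_matmul(C, P, transb=True), (N, N)):
    entry [i][j] is the dot product of child vector i and parent vector j."""
    return [[dot(c, p) for p in parents] for c in children]


def root_score_row(roots: List[Vector], weight: Vector, bias: float) -> Matrix:
    """Mirrors reshape(r_scorer(R), (1, N)), ASSUMING r_scorer is a linear
    layer hidden_size -> 1 with the (hypothetical) given weight and bias."""
    return [[dot(r, weight) + bias for r in roots]]


# Toy example: N = 3 words, hidden_size = 2.
C = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # child vectors
P = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]  # parent vectors
R = [[1.0, 2.0], [0.5, 0.5], [0.0, 1.0]]  # root-candidate vectors

parent_scores = parent_score_matrix(C, P)
# → [[2.0, 0.0, 1.0], [0.0, 3.0, 1.0], [2.0, 3.0, 2.0]]
root_scores = root_score_row(R, [1.0, -1.0], 0.5)
# → [[-0.5, 0.5, -0.5]]
```

Row i of `parent_scores` holds the candidate-parent scores for word i, while `root_scores` is a single row, matching the (N, N) and (1, N) shapes returned by forward().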