base_parser.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Sing_Par 作者: wanghm92 项目源码 文件源码
def validate(self, mb_inputs, mb_targets, mb_probs):
  """Build one integer prediction table per sentence of a minibatch.

  For each sentence, the non-root tokens are selected, head and relation
  predictions are decoded with ``self.prob_argmax``, and everything is
  packed into a ``(length, 9)`` integer array:

    column 0      1-based token index
    columns 1-3   rows ``1..length`` of the sentence's ``inputs``
    column 4      column 0 of the sentence's ``targets``
    column 5      predicted heads (first value returned by ``prob_argmax``)
    column 6      predicted relations (second value from ``prob_argmax``)
    columns 7-    columns ``1:`` of the sentence's ``targets``

  Args:
    mb_inputs: per-sentence input arrays. Column 0 is compared against
      ``Vocab.ROOT`` to decide which tokens are kept. Presumably each array
      has exactly 3 columns so it fits output columns 1-3 (TODO confirm).
    mb_targets: per-sentence target arrays, indexed by the same rows as
      ``mb_inputs``.
    mb_probs: a pair ``(parse_probs, rel_probs)``. Each element is a
      per-sentence sequence that is passed to ``self.prob_argmax``.

  Returns:
    A list with one ``(length, 9)`` int array per sentence, where ``length``
    is the number of tokens whose column-0 id exceeds ``Vocab.ROOT``.
  """
  mb_parse_probs, mb_rel_probs = mb_probs
  batch = zip(mb_inputs, mb_targets, mb_parse_probs, mb_rel_probs)

  sents = []
  for inputs, targets, parse_probs, rel_probs in batch:
    keep_mask = np.greater(inputs[:, 0], Vocab.ROOT)
    n_tokens = np.sum(keep_mask)
    head_preds, label_preds = self.prob_argmax(parse_probs, rel_probs, keep_mask)

    # NOTE(review): the kept tokens are assumed to sit at rows 1..n_tokens
    # (row 0 skipped, padding after them); the mask itself is not used for
    # indexing here -- confirm against the data layout.
    idx = np.arange(1, n_tokens + 1)
    table = np.full((n_tokens, 9), -1, dtype=int)
    table[:, 0] = idx
    table[:, 1:4] = inputs[idx]
    table[:, 4] = targets[idx, 0]
    table[:, 5] = head_preds[idx]
    table[:, 6] = label_preds[idx]
    table[:, 7:] = targets[idx, 1:]
    sents.append(table)
  return sents

  #=============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号