def validate(self, mb_inputs, mb_targets, mb_probs):
    """Build one integer prediction matrix per sentence in a minibatch.

    Args:
      mb_inputs: per-sentence input arrays. Column 0 is compared against
        ``Vocab.ROOT`` to decide which tokens to keep, and columns 1-3 of the
        output are copied from ``inputs[tokens]`` (assumes each inputs array
        has exactly 3 columns -- TODO confirm against the data pipeline).
      mb_targets: per-sentence gold arrays. Column 0 goes to output column 4
        and columns 1: go to output columns 7-8 (assumes 3 target columns).
      mb_probs: a pair ``(mb_parse_probs, mb_rel_probs)``. For each sentence
        these are passed, together with the keep-mask, to
        ``self.prob_argmax``, which returns ``(parse_preds, rel_preds)``.

    Returns:
      A list with one ``int`` array of shape ``(length, 9)`` per sentence,
      where ``length`` is the number of kept tokens. The columns are:
      0 = token position (1..length), 1-3 = inputs, 4 = targets[:, 0],
      5 = parse_preds, 6 = rel_preds, 7-8 = targets[:, 1:].
      Presumably columns 5/6 are predicted heads and relation labels; the
      names suggest this, but it is not verified here.

    Note: rows are taken from positions 1..length. This assumes the kept
    tokens are contiguous right after position 0, with any dropped tokens
    (root/padding) at position 0 and at the end -- TODO confirm.
    """
    mb_parse_probs, mb_rel_probs = mb_probs
    results = []
    batch = zip(mb_inputs, mb_targets, mb_parse_probs, mb_rel_probs)
    for sent_inputs, sent_targets, sent_parse_probs, sent_rel_probs in batch:
        # Keep only tokens whose id in column 0 is strictly above Vocab.ROOT.
        keep_mask = sent_inputs[:, 0] > Vocab.ROOT
        n_tokens = keep_mask.sum()
        head_preds, label_preds = self.prob_argmax(
            sent_parse_probs, sent_rel_probs, keep_mask)

        # -1 marks any cell that is not filled explicitly below.
        rows = np.full((n_tokens, 9), -1, dtype=int)
        idx = np.arange(1, n_tokens + 1)
        rows[:, 0] = idx
        rows[:, 1:4] = sent_inputs[idx]
        rows[:, 4] = sent_targets[idx, 0]
        rows[:, 5] = head_preds[idx]
        rows[:, 6] = label_preds[idx]
        rows[:, 7:] = sent_targets[idx, 1:]
        results.append(rows)
    return results
#=============================================================
评论列表
文章目录