def decode(self, input_word, input_char, input_pos, mask=None, length=None, hx=None, leading_symbolic=0):
    """Greedily predict a head index and a type for every position.

    The arc scores produced by ``self.forward`` are arg-maxed independently
    for each position along dim 1; the chosen heads are then handed to
    ``self._decode_types`` to obtain the types.

    Args:
        input_word, input_char, input_pos: passed through unchanged to
            ``self.forward``.
        mask: optional mask forwarded to ``self.forward``. The mask returned
            by ``forward`` is the one applied here: entries equal to 0 are
            treated as invalid positions (assumes a 0/1 or boolean mask of
            shape [batch, length] -- confirm against ``forward``).
        length: forwarded to ``self.forward``.
        hx: forwarded to ``self.forward``.
        leading_symbolic: forwarded to ``self._decode_types``.

    Returns:
        A tuple ``(heads, types)`` of numpy arrays. ``heads`` has shape
        [batch, length]; ``types`` is whatever ``self._decode_types``
        returns, moved to CPU and converted to numpy.
    """
    # out_arc shape [batch, length, length]
    out_arc, out_type, mask, length = self.forward(input_word, input_char, input_pos,
                                                   mask=mask, length=length, hx=hx)
    out_arc = out_arc.data
    _, max_len, _ = out_arc.size()

    # Set diagonal elements to -inf so a position can never select itself.
    # The addition is out-of-place, so the tensor returned by forward() is
    # left untouched by the in-place fill below.
    out_arc = out_arc + torch.diag(out_arc.new(max_len).fill_(-np.inf))

    # Set invalid positions to -inf.
    if mask is not None:
        # Comparing with 0 yields a boolean mask on current PyTorch (uint8
        # masks are rejected by masked_fill_ there) and a byte mask on legacy
        # releases, so this works on both. It also accepts boolean masks, for
        # which `1 - mask` is not defined.
        invalid = (mask.data == 0).unsqueeze(2)
        out_arc.masked_fill_(invalid, -np.inf)

    # Compute naive predictions: best-scoring index along dim 1.
    # prediction shape = [batch, length]
    _, heads = out_arc.max(dim=1)

    type_ids = self._decode_types(out_type, heads, leading_symbolic)

    return heads.cpu().numpy(), type_ids.data.cpu().numpy()
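# Comment list (page-navigation text left over from the web page this snippet was copied from)
# Table of contents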