def decode(self, input_word, input_char, input_pos, mask=None, length=None, hx=None, beam=1, leading_symbolic=0, ordered=True):
    """Decode dependency trees for a batch with the stack-pointer decoder.

    Args:
        input_word: word index tensor for the batch.
        input_char: character index tensor for the batch.
        input_pos: POS-tag index tensor for the batch.
        mask: optional padding mask forwarded to the encoder.
        length: optional per-sentence lengths; when None, the encoder-derived
            lengths are used (may still be None per sentence).
        hx: optional initial encoder hidden state.
        beam: beam width for per-sentence search.
        leading_symbolic: number of leading symbolic labels to exclude.
        ordered: decode with the ordered transition system first; falls back
            to unordered if ordered search yields no complete tree.

    Returns:
        Tuple of numpy int32 arrays:
            heads       [batch, max_len_e]      predicted head index per token
            types       [batch, max_len_e]      predicted dependency label per token
            children    [batch, 2*max_len_e-1]  transition child sequence
            stack_types [batch, 2*max_len_e-1]  transition label sequence

    Raises:
        RuntimeError: if a sentence cannot be decoded even with the
            unordered fallback.
    """
    # Disable decoder dropout noise so decoding is deterministic.
    self.decoder.reset_noise(0)
    # Encoder outputs:
    #   src_encoding [batch, length, input_size]
    #   output_enc   [batch, length, tag_space]
    #   hn           final encoder state, used to initialize the decoder
    src_encoding, output_enc, hn, mask, length = self._get_encoder_output(input_word, input_char, input_pos, mask_e=mask, length_e=length, hx=hx)
    # Project encoder outputs into the arc / type scoring spaces:
    # [batch, length_encoder, arc_space] and [batch, length_encoder, type_space].
    arc_c = F.elu(self.arc_c(output_enc))
    type_c = F.elu(self.type_c(output_enc))
    hn = self._transform_decoder_init_state(hn)
    batch, max_len_e, _ = src_encoding.size()

    heads = np.zeros([batch, max_len_e], dtype=np.int32)
    types = np.zeros([batch, max_len_e], dtype=np.int32)
    # A transition sequence over a sentence of n tokens has 2n - 1 steps.
    children = np.zeros([batch, 2 * max_len_e - 1], dtype=np.int32)
    stack_types = np.zeros([batch, 2 * max_len_e - 1], dtype=np.int32)

    for b in range(batch):
        sent_len = None if length is None else length[b]
        # Slice out this sentence's decoder init state. LSTM states are
        # (h, c) tuples; GRU/RNN states are a single tensor. Use a distinct
        # local name so the `hx` parameter is not shadowed.
        if isinstance(hn, tuple):
            h_all, c_all = hn
            hx_b = (h_all[:, b, :].contiguous(), c_all[:, b, :].contiguous())
        else:
            hx_b = hn[:, b, :].contiguous()
        preds = self._decode_per_sentence(src_encoding[b], output_enc[b], arc_c[b], type_c[b], hx_b, sent_len, beam, ordered, leading_symbolic)
        if preds is None:
            # Ordered beam search can fail to produce a complete tree;
            # retry with the unordered transition system.
            preds = self._decode_per_sentence(src_encoding[b], output_enc[b], arc_c[b], type_c[b], hx_b, sent_len, beam, False, leading_symbolic)
        if preds is None:
            # Fail loudly instead of crashing with an opaque unpack TypeError.
            raise RuntimeError('decoding failed for sentence %d in batch' % b)
        hids, tids, sent_len, chids, stids = preds
        heads[b, :sent_len] = hids
        types[b, :sent_len] = tids
        children[b, :2 * sent_len - 1] = chids
        stack_types[b, :2 * sent_len - 1] = stids

    return heads, types, children, stack_types