def val_sents(self, data, dec_logits):
    """Decode a validation batch into readable sentences for previewing.

    Undoes the per-row reordering described by ``xys_idx`` so that the
    target-side tensors line up with ``x`` again. It then keeps only the
    first ``self.previews`` examples and takes the argmax over the last axis
    of each logits tensor. Finally it maps every index to a token through
    ``self.model.vocab.i2f`` and joins the tokens with spaces.

    Args:
        data: Tuple ``(x, x_lens, ys_i, ys_t, ys_lens, xys_idx)``.
            - ``x`` is indexed as (batch, time).
            - ``ys_i``/``ys_t`` are indexed as (K, batch, time).
            - ``ys_lens``/``xys_idx`` are indexed as (K, batch).
            - The first two axes are swapped when ``self.batch_first`` is
              set.
            - K is presumably one entry per decoder / target stream
              (TODO confirm).
            - Each row of ``xys_idx`` is treated as a permutation of the
              batch; its inverse is applied.
        dec_logits: Scores indexed as (K, batch, time, vocab). The first two
            axes are swapped when ``self.batch_first`` is set.

    Returns:
        ``(x_sents, yi_sents, yt_sents, o_sents)``.
        - ``x_sents`` is a list of source sentences, one per previewed
          example.
        - The other three are lists (one per previewed example) of lists
          (one per K entry) of sentences. They hold the target inputs, the
          target outputs and the argmax predictions, respectively.
        - Target and prediction sentences are truncated to ``ys_lens``.
    """
    vocab, previews = self.model.vocab, self.previews
    x, x_lens, ys_i, ys_t, ys_lens, xys_idx = data

    # Bring every target-side tensor to (K, batch, ...) layout.
    if self.batch_first:
        ys_i, ys_t, ys_lens, xys_idx, dec_logits = (
            d.transpose(1, 0).contiguous()
            for d in (ys_i, ys_t, ys_lens, xys_idx, dec_logits))

    # Sorting a permutation yields its inverse. Gathering with the inverse
    # restores the original example order along the batch axis.
    _, xys_ridx = torch.sort(xys_idx, 1)
    xys_ridx_exp = xys_ridx.unsqueeze(-1).expand_as(ys_i)
    ys_i = torch.gather(ys_i, 1, xys_ridx_exp)
    ys_t = torch.gather(ys_t, 1, xys_ridx_exp)
    ys_lens = torch.gather(ys_lens, 1, xys_ridx)
    dec_logits = [torch.index_select(logits, 0, ridx)
                  for logits, ridx in zip(dec_logits, xys_ridx)]

    # Keep only the examples that will be previewed.
    x, x_lens = x[:previews], x_lens[:previews]
    ys_i, ys_t = ys_i[:, :previews], ys_t[:, :previews]
    ys_lens = ys_lens[:, :previews]

    # Greedy predictions: argmax over the last axis gives (rows, time) per
    # entry. The original code used ``squeeze(-1)`` here, which was meant to
    # drop a kept reduction axis. When ``max`` already drops that axis,
    # ``squeeze(-1)`` instead collapses a length-1 time axis. Reshaping to
    # the known (rows, time) size is correct either way.
    preds = []
    for logits in dec_logits:
        logits = logits[:previews]
        best = logits.max(2)[1].view(logits.size(0), logits.size(1))
        preds.append(best.unsqueeze(0))
    dec_preds = torch.cat(preds, 0)

    # Switch to (example, K, ...) layout and convert to nested Python lists.
    ys_i, ys_t = ys_i.transpose(1, 0), ys_t.transpose(1, 0)
    dec_preds = dec_preds.transpose(1, 0)
    ys_lens = ys_lens.transpose(1, 0)
    x, x_lens = x.data.tolist(), x_lens.data.tolist()
    ys_i, ys_t = ys_i.data.tolist(), ys_t.data.tolist()
    dec_preds, ys_lens = dec_preds.data.tolist(), ys_lens.data.tolist()

    i2f = vocab.i2f

    def to_sents(rows, lens):
        """Join the first ``lens[k]`` tokens of ``rows[k]`` for every k."""
        return [" ".join(i2f[row[i]] for i in range(length))
                for row, length in zip(rows, lens)]

    x_sents = to_sents(x, x_lens)
    yi_sents = [to_sents(yi, y_lens) for yi, y_lens in zip(ys_i, ys_lens)]
    yt_sents = [to_sents(yt, y_lens) for yt, y_lens in zip(ys_t, ys_lens)]
    o_sents = [to_sents(pred, y_lens)
               for pred, y_lens in zip(dec_preds, ys_lens)]
    return x_sents, yi_sents, yt_sents, o_sents
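# Comment list / Table of contents -- navigation labels left over from the
# web page this snippet was copied from (originally "评论列表" / "文章目录").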