def forward(self, x, x_lens, ys, ys_lens, xys_idx):
    # Embed the input tokens and run the encoder over the padded batch.
    x = self.embeddings(x)
    h = self._encode_embed(x, x_lens)
    # Move the decoder axis to the front so we can iterate one decoder at a time.
    if self.batch_first:
        ys = ys.transpose(1, 0)
        ys_lens = ys_lens.transpose(1, 0)
        xys_idx = xys_idx.transpose(1, 0)
    logits_list = []
    for dec_idx, (y, y_lens, xy_idx) in enumerate(
            zip(ys, ys_lens, xys_idx)):
        # Select the encoder states that feed this decoder.
        h_dec = torch.index_select(h, 0, xy_idx)
        logits = self._decode(dec_idx, h_dec, y, y_lens)
        # The decoder may drop some examples; pad its logits back up to the
        # encoder batch size so all decoders stack to the same shape.
        nil_batches = len(h_dec) - len(logits)
        if nil_batches:
            logits = pad_batch(logits, nil_batches, True)
        logits_list.append(logits.unsqueeze(0))
    # Stack the per-decoder logits and restore the original batch layout.
    logits = torch.cat(logits_list)
    if self.batch_first:
        logits = logits.transpose(1, 0)
    return logits, h
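
The `pad_batch` helper is not defined in this excerpt. Below is a minimal sketch of what it might look like, inferred only from how it is called above: the decoder can return fewer rows than the encoder batch, so the logits are zero-padded along the batch dimension before stacking. The helper's signature and the meaning of its third argument are assumptions, not the original implementation.

```python
import torch

def pad_batch(logits, n_pad, batch_first=True):
    # Hypothetical helper, inferred from its use in forward(): append
    # `n_pad` zero rows along the batch dimension so every decoder's
    # logits share the same batch size before torch.cat stacks them.
    if n_pad <= 0:
        return logits
    dim = 0 if batch_first else 1  # assumed meaning of the third argument
    pad_shape = list(logits.shape)
    pad_shape[dim] = n_pad
    padding = logits.new_zeros(pad_shape)
    return torch.cat([logits, padding], dim=dim)
```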