def vectorize(self, x, x_lens):
    """Encode a batch with ``self.model`` and return the resulting vectors.

    The batch is first passed through ``self.prepare_batch``, which returns
    the (possibly reordered) inputs plus an index tensor ``r_idx``. After
    encoding, rows are gathered with ``r_idx`` along dim 0. Presumably this
    restores the caller's original batch order; confirm against
    ``prepare_batch``.

    Args:
        x: Input tensor. A 2-D input is encoded with ``self.model.encode``;
            a 3-D input is encoded with ``self.model.encode_embed``.
        x_lens: Passed through ``prepare_batch`` to the encoder (presumably
            per-sequence lengths; verify against caller).

    Returns:
        The encoded vectors, detached from the autograd graph. They are
        moved to CPU when ``self.is_cuda`` is set.

    Raises:
        ValueError: If ``x`` (after ``prepare_batch``) is not 2-D or 3-D.
    """
    x, x_lens, r_idx = self.prepare_batch(x, x_lens)
    ndim = x.dim()
    if ndim == 2:
        vectors = self.model.encode(x, x_lens)
    elif ndim == 3:
        vectors = self.model.encode_embed(x, x_lens)
    else:
        raise ValueError(
            f"vectorize expected a 2-D or 3-D input, got a {ndim}-D tensor "
            f"of shape {tuple(x.size())}"
        )
    # Reorder rows using the index produced by prepare_batch.
    vectors = torch.index_select(vectors, 0, r_idx)
    if self.is_cuda:
        vectors = vectors.cpu()
    # detach() replaces the legacy .data accessor: same grad-free tensor,
    # but autograd can still detect unsafe in-place modifications.
    return vectors.detach()
评论列表
文章目录