def _roll_rows(t, shift):
    """Cyclically shift the rows (dim 0) of ``t`` by ``shift`` positions.

    Matches ``np.roll(t, shift, axis=0)``: for ``shift == 1`` row ``i`` of the
    result is row ``i - 1`` of ``t``. The result stays on ``t``'s device and
    keeps its dtype; a new tensor is always returned.
    """
    n = t.size(0)
    if n == 0:
        return t.clone()
    shift %= n
    if shift == 0:
        # A roll by a multiple of the length (including any roll of a
        # single row) is the identity.
        return t.clone()
    return torch.cat((t[-shift:], t[:-shift]), 0)


def preprocess(img, desc, len_desc, txt_encoder):
    """Move a batch to the configured device and encode its descriptions.

    Reads the module-level ``args.no_cuda`` flag to decide whether ``img`` and
    ``desc`` are moved to the GPU.

    Args:
        img: Image batch. It is only moved to the device and wrapped in
            ``Variable``.
        desc: Padded description batch, batch axis first. It is indexed by a
            permutation of the batch, then transposed to time-major before
            packing.
        len_desc: CPU tensor holding the length of each description in
            ``desc``. Each length is presumably > 0, as ``pack_padded_sequence``
            requires.
        txt_encoder: Module called on a ``PackedSequence``. Its second return
            value is used as the text feature. Assumed (TODO confirm) to be a
            single-layer, unidirectional RNN such as ``nn.GRU``, whose final
            hidden state has shape ``(1, batch, hidden)``.

    Returns:
        ``(img, txt_feat, txt_feat_mismatch, txt_feat_relevant)``:

        * ``txt_feat`` is in the original batch order.
        * ``txt_feat_mismatch`` is ``txt_feat`` with its rows rolled by one,
          so row ``i`` holds row ``i - 1``'s features.
        * ``txt_feat_relevant`` rolls only the first ``batch // 2`` rows by
          -1 and leaves the rest unchanged.

        The last two are built from ``txt_feat.data`` and so are detached from
        the autograd graph.
    """
    use_cuda = not args.no_cuda
    img = Variable(img.cuda() if use_cuda else img)
    desc = Variable(desc.cuda() if use_cuda else desc)

    # pack_padded_sequence (enforce_sorted default) needs decreasing lengths.
    # ``.copy()`` matters: the reversed view has a negative stride, and
    # PyTorch refuses such arrays when they are used to index a tensor.
    len_desc = len_desc.numpy()
    sorted_indices = np.argsort(len_desc)[::-1].copy()
    original_indices = np.argsort(sorted_indices)  # inverse permutation
    packed_desc = nn.utils.rnn.pack_padded_sequence(
        desc[sorted_indices, ...].transpose(0, 1),
        len_desc[sorted_indices].tolist(),
    )
    _, txt_feat = txt_encoder(packed_desc)
    # Drop only the leading layer/direction axis. A bare squeeze() would also
    # drop the batch axis when the batch holds a single sample.
    txt_feat = txt_feat.squeeze(0)
    txt_feat = txt_feat[original_indices, ...]

    # Build the shifted pairings from detached features directly on the
    # features' device, avoiding a GPU -> host -> GPU round trip.
    feat = txt_feat.data
    txt_feat_mismatch = Variable(_roll_rows(feat, 1))
    half = feat.size(0) // 2
    txt_feat_relevant = Variable(
        torch.cat((_roll_rows(feat[:half], -1), feat[half:]), 0)
    )
    return img, txt_feat, txt_feat_mismatch, txt_feat_relevant
评论列表
文章目录