train.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:dong_iccv_2017 作者: woozzu 项目源码 文件源码
def preprocess(img, desc, len_desc, txt_encoder):
    """Prepare one batch: move inputs to the device and encode the captions.

    The captions are sorted by descending length so they can be packed with
    ``nn.utils.rnn.pack_padded_sequence``, encoded with ``txt_encoder``, and
    the resulting features are put back into the original batch order.

    Two extra text-feature tensors are derived from the encoded features by
    permuting rows within the batch (see Returns). They are built from
    ``txt_feat.data`` through a NumPy round trip, so they carry no autograd
    history back into ``txt_encoder``.

    Args:
        img: Image batch tensor. It is wrapped in a ``Variable`` (moved to the
            GPU unless ``args.no_cuda`` is set) and otherwise left untouched.
        desc: Padded caption batch. It is indexed by batch along dim 0 and
            transposed with ``transpose(0, 1)`` before packing, so it is
            presumably shaped (batch, seq_len, ...). Confirm against the
            data loader.
        len_desc: Tensor of caption lengths, one per batch element. It is
            converted with ``.numpy()``, so it must live on the CPU.
        txt_encoder: Module called on the packed sequence. Its second return
            value is used as the text feature. Presumably this is an RNN
            whose final hidden state is shaped (1, batch, hidden). TODO:
            confirm.

    Returns:
        Tuple ``(img, txt_feat, txt_feat_mismatch, txt_feat_relevant)``:
        - ``txt_feat``: encoded caption features in the original batch order.
        - ``txt_feat_mismatch``: ``txt_feat`` rolled by one row, so row ``i``
          holds the features of row ``i - 1`` (row 0 gets the last row).
        - ``txt_feat_relevant``: the first ``batch // 2`` rows rolled by -1
          (row ``i`` gets row ``i + 1``, wrapping within that first part);
          the remaining rows are left unchanged.
    """
    use_cuda = not args.no_cuda

    def to_variable(tensor):
        # Move to the GPU unless disabled, then wrap for autograd.
        return Variable(tensor.cuda() if use_cuda else tensor)

    img = to_variable(img)
    desc = to_variable(desc)

    # pack_padded_sequence (without enforce_sorted=False) expects the batch
    # sorted by decreasing sequence length.
    len_desc = len_desc.numpy()
    # ``[::-1]`` yields a negative-stride view. Copy it so the array can be
    # used to index a torch tensor, because torch rejects NumPy arrays with
    # negative strides.
    sorted_indices = np.argsort(len_desc)[::-1].copy()
    # Inverse permutation: maps sorted positions back to the original order.
    original_indices = np.argsort(sorted_indices)
    packed_desc = nn.utils.rnn.pack_padded_sequence(
        desc[sorted_indices, ...].transpose(0, 1),
        len_desc[sorted_indices]
    )
    _, txt_feat = txt_encoder(packed_desc)
    # Drop only the leading (presumably layer) dim. A bare squeeze() would
    # also drop the batch dim when the batch holds a single caption, which
    # breaks the row reordering below.
    txt_feat = txt_feat.squeeze(0)
    txt_feat = txt_feat[original_indices, ...]

    # .cpu() returns the tensor itself when it is already on the CPU.
    txt_feat_np = txt_feat.data.cpu().numpy()

    # Pair every sample with the text features of the previous sample.
    txt_feat_mismatch = to_variable(
        torch.Tensor(np.roll(txt_feat_np, 1, axis=0))
    )

    # Roll only the first half of the batch; keep the second half as-is.
    half = txt_feat_np.shape[0] // 2
    first_half, second_half = np.split(txt_feat_np, [half])
    txt_feat_relevant = to_variable(torch.Tensor(np.concatenate([
        np.roll(first_half, -1, axis=0),
        second_half,
    ])))
    return img, txt_feat, txt_feat_mismatch, txt_feat_relevant
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号