sample.py 文件源码

python
阅读 50 收藏 0 点赞 0 评论 0

项目:ssta-captioning 作者: Yugnaynehc 项目源码 文件源码
def _frame_path(img_dir, step, frame_id, tag):
    """Return the output image path ``<img_dir>/<step>_<frame_id>_<tag>.jpg``."""
    return os.path.join(img_dir, '%d_%d_%s.jpg' % (step, frame_id, tag))


def sample(vocab, video_feat, decoder, video_path, vid):
    """Decode a caption for one video and visualise the attended frames.

    For every generated word (up to the ``<end>`` token), the frame with the
    highest attention weight is shown in an OpenCV window named ``Attend``.
    It is written to ``<visual_dir>/<vid>/`` together with its saliency map
    and a saliency-masked foreground image. The words are printed as they are
    decoded, and the joined caption is printed at the end.

    Args:
        vocab: vocabulary object; ``vocab(word)`` maps a word to its id and
            ``vocab.idx2word`` maps an id back to the word.
        video_feat: feature tensor for a single video; a leading dimension
            (presumably the batch axis) is added here before decoding.
        decoder: model whose ``sample(feats)`` returns ``(outputs, attens)``,
            i.e. the generated token ids and one attention entry per step.
        video_path: path of the source video, decoded with ``open_video``.
        vid: video identifier, used as the name of the output sub-directory.

    Returns:
        None. Returns early, without printing the full caption, if the user
        presses ``n`` while a frame is being displayed.
    """
    # Per-video directory for attended frames, saliency maps and foregrounds.
    img_dir = os.path.join(visual_dir, str(vid))
    if not os.path.exists(img_dir):
        # makedirs (not mkdir) so a missing visual_dir is created as well.
        os.makedirs(img_dir)

    frame_list = open_video(video_path)
    if use_cuda:
        video_feat = video_feat.cuda()
    video_feat = video_feat.unsqueeze(0)
    outputs, attens = decoder.sample(video_feat)

    end_id = vocab('<end>')  # loop-invariant; look it up once
    words = []
    # Flatten with view(-1) rather than squeeze(): squeeze() turns a
    # single-token output into a 0-d tensor, which cannot be iterated.
    tokens = outputs.data.contiguous().view(-1)
    for i, token in enumerate(tokens):
        # Tensor elements hash by identity in recent PyTorch, so convert to a
        # plain int before comparing and using it as a dict key.
        token = int(token)
        if token == end_id:
            break
        word = vocab.idx2word[token]
        print(word)
        words.append(word)

        # Only the most-attended frame is used, so ask for the top-1 index;
        # a fixed top-5 would also fail for videos with fewer than 5 frames.
        _, top_idx = torch.topk(attens[i], 1)
        selected_id = int(top_idx.data.contiguous().view(-1)[0])
        selected_frame = frame_list[selected_id]
        cv2.imshow('Attend', selected_frame)
        cv2.imwrite(_frame_path(img_dir, i, selected_id, word), selected_frame)

        # Saliency map of the attended frame, saved alongside it.
        sal = psal.get_saliency_rbd(selected_frame).astype('uint8')
        cv2.imwrite(_frame_path(img_dir, i, selected_id, 'saliency'), sal)

        # Keep only the salient foreground: binarise the saliency map and
        # replicate it to 3 channels (assumes 3-channel frames — TODO confirm).
        binary_sal = psal.binarise_saliency_map(sal, method='adaptive')
        mask = binary_sal[:, :, np.newaxis]
        binary_mask = np.concatenate((mask, mask, mask), axis=2)
        foreground_img = np.multiply(selected_frame, binary_mask).astype('uint8')
        cv2.imwrite(_frame_path(img_dir, i, selected_id, 'foreground'),
                    foreground_img)

        # Mask to the low byte: some HighGUI backends set extra high bits in
        # the returned key code, which would break the comparison.
        key = cv2.waitKey(500)
        if (key & 0xFF) == ord('n'):
            return
    caption = ' '.join(words)
    print(caption)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号