import os

import cv2
import numpy as np
import torch
import pyimgsaliency as psal

# `visual_dir`, `use_cuda` and `open_video` are assumed to be defined
# elsewhere in the original script.


def sample(vocab, video_feat, decoder, video_path, vid):
    # create a per-video directory to hold the attention visualisations
    img_dir = os.path.join(visual_dir, str(vid))
    if not os.path.exists(img_dir):
        os.mkdir(img_dir)
    frame_list = open_video(video_path)
    if use_cuda:
        video_feat = video_feat.cuda()
    video_feat = video_feat.unsqueeze(0)
    outputs, attens = decoder.sample(video_feat)

    words = []
    for i, token in enumerate(outputs.data.squeeze()):
        token = int(token)  # ensure a plain int for vocab lookups
        if token == vocab('<end>'):
            break
        word = vocab.idx2word[token]
        print(word)
        words.append(word)

        # pick the frame that received the most attention at this time step
        v, k = torch.topk(attens[i], 5)
        selected_id = int(k.data[0][0])
        selected_frame = frame_list[selected_id]
        cv2.imshow('Attend', selected_frame)
        cv2.imwrite(os.path.join(img_dir, '%d_%d_%s.jpg' % (i, selected_id, word)),
                    selected_frame)
        # saliency detection on the attended frame
        sal = psal.get_saliency_rbd(selected_frame).astype('uint8')
        cv2.imwrite(os.path.join(img_dir, '%d_%d_%s.jpg' % (i, selected_id, 'saliency')),
                    sal)

        # binarise the saliency map and keep only the foreground pixels
        binary_sal = psal.binarise_saliency_map(sal, method='adaptive')
        I = binary_sal[:, :, np.newaxis]
        binary_mask = np.concatenate((I, I, I), axis=2)
        foreground_img = np.multiply(selected_frame, binary_mask).astype('uint8')
        cv2.imwrite(os.path.join(img_dir, '%d_%d_%s.jpg' % (i, selected_id, 'foreground')),
                    foreground_img)

        # press 'n' within 500 ms to skip the rest of this video;
        # `key` replaces the original `k`, which shadowed the top-k indices above
        key = cv2.waitKey(500)
        if key == ord('n'):
            return
    caption = ' '.join(words)
    print(caption)
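
For reference, here is a minimal sketch of how sample() might be driven. The helpers load_vocab() and load_video_feature(), the checkpoint path, and the video id are assumptions for illustration, not part of the original script:

if __name__ == '__main__':
    vocab = load_vocab('vocab.pkl')                  # hypothetical helper
    decoder = torch.load('decoder_checkpoint.pth')   # assumed checkpoint path
    decoder.eval()
    vid = 1301                                       # assumed video id
    video_path = os.path.join('videos', '%d.mp4' % vid)
    video_feat = load_video_feature(vid)             # hypothetical helper
    sample(vocab, video_feat, decoder, video_path, vid)
    cv2.destroyAllWindows()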