beam_search.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ImageCaptioning 作者: rkuga 项目源码 文件源码
def beam_search(dec, state, y, data, beam_width, mydict_inv,
                max_len=50, topk=20, eos_id=15):
    """Decode one token sequence per batch row using beam search.

    Args:
        dec: Decoder step, called as ``dec(token_ids, state, train=False)``
            and returning ``(new_state, scores)``. ``new_state`` is indexed
            with the keys ``'h1'`` and ``'c1'``; ``scores`` is passed
            through ``F.log_softmax``.
        state: Decoder state handed to the very first ``dec`` call.
        y: Variable whose ``.data`` lives on the device (``.get()`` is
            called on it); its argmax over axis 1 is the first token fed
            to the decoder.
        data: Only ``data.shape[0]`` is read, as the batch size.
        beam_width: Number of hypotheses kept per batch row.
        mydict_inv: Not used by the search; kept in the signature so
            existing callers keep working.
        max_len: Maximum number of decoding steps (previously a
            hard-coded 50).
        topk: Number of best next tokens considered per hypothesis at
            each step (previously a hard-coded 20).
        eos_id: Token id that stops decoding once *every* hypothesis of
            *every* row emits it at the same step (previously a
            hard-coded 15; presumably the end-of-sentence id of the
            caller's vocabulary -- confirm against the dictionary).

    Returns:
        A list with one entry per batch row: the token ids of the
        highest-scoring hypothesis, truncated at the step where decoding
        stopped.

    NOTE(review): a hypothesis that has already emitted ``eos_id`` keeps
    being extended; only the global stop condition above ends the search.
    """
    xp = cuda.cupy
    batchsize = data.shape[0]
    if max_len <= 0:
        return [[] for _ in range(batchsize)]

    # route[i, b, t] is the token chosen at step t by beam b of row i.
    route = np.zeros((batchsize, beam_width, max_len), dtype=np.int32)

    # Step 0: one decoder call, whose top `beam_width` tokens seed the beams.
    y = Variable(xp.array(np.argmax(y.data.get(), axis=1)).astype(xp.int32))
    state, y = dec(y, state, train=False)
    # Every beam starts from the same recurrent state.
    h = xp.tile(state['h1'].data.reshape(batchsize, 1, -1), (1, beam_width, 1))
    c = xp.tile(state['c1'].data.reshape(batchsize, 1, -1), (1, beam_width, 1))
    log_prob = F.log_softmax(y).data.get()
    route[:, :, 0] = np.argsort(log_prob)[:, ::-1][:, :beam_width]
    # Cumulative log-probability of each beam, shape (batchsize, beam_width).
    total_score = np.sort(log_prob)[:, ::-1][:, :beam_width]

    # Column vector of row numbers; pairs each row with its own beam picks
    # in the gathers below.
    rows = np.arange(batchsize)[:, None]

    j = 0
    for j in range(1, max_len):
        next_score = np.zeros((batchsize, beam_width, topk))
        next_token = np.zeros((batchsize, beam_width, topk), dtype=np.int32)

        # Advance every beam by one step and record its topk continuations.
        for b in range(beam_width):
            state = {'c1': Variable(c[:, b, :]), 'h1': Variable(h[:, b, :])}
            cur_token = xp.array(route[:, b, j - 1]).astype(xp.int32)
            state, y = dec(cur_token, state, train=False)
            h[:, b, :] = state['h1'].data
            c[:, b, :] = state['c1'].data
            log_prob = F.log_softmax(y).data.get()
            next_score[:, b, :] = np.sort(log_prob, axis=1)[:, ::-1][:, :topk]
            next_token[:, b, :] = np.argsort(log_prob, axis=1)[:, ::-1][:, :topk]

        # Score of every (beam, continuation) pair. The in-place add keeps
        # the accumulator in the dtype of the step-0 log-probabilities.
        cand_score = np.repeat(
            total_score.reshape(batchsize, beam_width, 1), topk, axis=2)
        cand_score += next_score

        # Best `beam_width` candidates per row, best first.
        best = cand_score.reshape(batchsize, beam_width * topk)
        best = best.argsort(axis=1)[:, ::-1][:, :beam_width]
        src_beam = best // topk   # beam each winner extends
        src_rank = best % topk    # which of that beam's topk tokens it took

        total_score = cand_score[rows, src_beam, src_rank]
        new_token = next_token[rows, src_beam, src_rank]

        # Re-order the stored prefixes to follow the surviving beams, then
        # append the new tokens. The right-hand side is a copy, so the
        # in-place write cannot alias.
        route[:, :, :j] = route[rows, src_beam, :j]
        route[:, :, j] = new_token

        # Carry over the recurrent state of the beam each winner came from.
        dev_rows = xp.asarray(rows)
        dev_beam = xp.asarray(src_beam)
        h = h[dev_rows, dev_beam]
        c = c[dev_rows, dev_beam]

        if (route[:, :, j] == eos_id).all():
            break

    # Beam 0 holds the highest cumulative score in every row.
    return route[:, 0, :j + 1].tolist()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号