eval.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DREAM 作者: LaceyChen17 项目源码 文件源码
def get_item_embedding(pid, dr_model):
    '''
        get item's embedding
        pid can be a integer or a torch.cuda.LongTensor
    '''
    if isinstance(pid, torch.cuda.LongTensor) or isinstance(pid, torch.LongTensor):
        return dr_model.encode.weight[pid]
    elif isinstance(pid, int):
        return dr_model.encode.weight[pid].unsqueeze(0)
    else:
        print('Unsupported Index Type %s'%type(pid))
        return None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号