def get_item_embedding(pid, dr_model):
    """Look up item embedding(s) in ``dr_model.encode.weight``.

    Args:
        pid: Index of the item(s) to fetch. Either a plain ``int`` or a
            ``torch.LongTensor`` / ``torch.cuda.LongTensor`` of indices.
        dr_model: Object exposing the embedding table as
            ``dr_model.encode.weight``. It is only accessed when ``pid``
            has a supported type.

    Returns:
        For a LongTensor ``pid``: ``dr_model.encode.weight[pid]`` as-is.
        For an ``int`` ``pid``: the selected entry with an extra leading
        dimension of size 1 (``unsqueeze(0)``).
        For any other type: ``None``, after printing a message to stdout.
    """
    is_long_tensor = isinstance(pid, (torch.cuda.LongTensor, torch.LongTensor))

    # Guard clause: reject unsupported index types before touching the model.
    if not is_long_tensor and not isinstance(pid, int):
        print('Unsupported Index Type %s' % type(pid))
        return None

    selected = dr_model.encode.weight[pid]
    # Indexing with a plain int drops the leading dimension; add it back.
    return selected if is_long_tensor else selected.unsqueeze(0)
评论列表
文章目录