def random_pick(p, word, sampling_type):
    """Choose an index into ``p`` using the requested sampling strategy.

    Args:
        p: Array-like of non-negative weights (presumably a probability
            distribution over a vocabulary -- confirm against the caller).
            The weights do not need to sum to 1.
        word: Value compared against ``' '``; only consulted when
            ``sampling_type`` is ``'combined'``.
        sampling_type: One of
            ``'argmax'``   -- always return ``np.argmax(p)``;
            ``'weighted'`` -- draw an index at random, proportionally to ``p``;
            ``'combined'`` -- weighted draw when ``word == ' '``, argmax
            otherwise.

    Returns:
        The selected index into ``p``.

    Raises:
        ValueError: If ``sampling_type`` is not one of the supported modes,
            or if a weighted draw is requested on empty weights.
    """
    def weighted_pick(weights):
        """Draw one index with probability proportional to ``weights``."""
        cumulative = np.cumsum(weights)
        if cumulative.size == 0:
            raise ValueError("cannot sample from empty weights")
        # Scale by the last cumulative value rather than a separate np.sum():
        # the two can differ by floating-point rounding, which could push the
        # threshold past the final bucket and yield an out-of-range index.
        threshold = np.random.rand() * cumulative[-1]
        # side='right' guarantees that zero-weight entries are never chosen,
        # even when the threshold lands exactly on a bucket boundary.
        index = int(np.searchsorted(cumulative, threshold, side='right'))
        # Clamp as a last line of defence (e.g. all-zero weights).
        return min(index, cumulative.size - 1)

    if sampling_type == 'argmax':
        return np.argmax(p)
    if sampling_type == 'weighted':
        return weighted_pick(p)
    if sampling_type == 'combined':
        # Sample randomly only right after a space; be greedy otherwise.
        if word == ' ':
            return weighted_pick(p)
        return np.argmax(p)
    raise ValueError("unknown sampling_type: %r" % (sampling_type,))
# test code
utils.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录