generate.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:lesson 作者: SPJ-AI 项目源码 文件源码
def get_index_a(_model):
    _model.predictor.reset_state()
    _sentence_index_a = []
    index = BOS_INDEX
    while index != EOS_INDEX:
        y = _model.predictor(xp.array([index], dtype=xp.int32))
        probability = F.softmax(y)
        probability.data[0] /= sum(probability.data[0])
        try:
            #???????????????????
            #index = np.argmax(probability.data[0])
            index = xp.random.choice(range(len(probability.data[0])), p=probability.data[0])
            if index!=EOS_INDEX:
                #??<EOS>???????
                _sentence_index_a.append(index)
        except Exception as e:
            print('probability error')
            break

    return _sentence_index_a
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号