inference.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:SERT 作者: cvangysel 项目源码 文件源码
def create(predict_fn, word_representations,
           batch_size, window_size, vocabulary_size,
           result_callback):
    assert result_callback is not None

    instance_dtype = np.min_scalar_type(vocabulary_size - 1)
    logging.info('Instance elements will be stored using %s.', instance_dtype)

    if result_callback.should_average_input():
        batcher = EmbeddingMapper(
            predict_fn,
            word_representations,
            result_callback)
    else:
        batcher = WordBatcher(
            predict_fn,
            batch_size, window_size,
            instance_dtype,
            result_callback)

    return batcher
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号