run_ngtm.py source code

python

Project: neural_topic_models    Author: dallascard
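import numpy as np
from sklearn.preprocessing import normalize  # assumed source of normalize(): row-wise L2 by default


def save_mean_representations(model, model_filename, X, labels, pred_file):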
    n_items, dv = X.shape
    n_classes = model.n_classes
    n_topics = model.d_t

    # L2-normalize each input row (one document vector per row)
    test_X = normalize(np.array(X, dtype='float32'), axis=1)

    model.load_params(model_filename)

    # compute the mean representation of each test document
    item_mus = []
    for item in range(n_items):
        y = labels[item]

        # save the mean document representation
        r_mu = model.get_mean_doc_rep(test_X[item, :], y)
        item_mus.append(np.array(r_mu))

    # write all the test doc representations to file
    if pred_file is not None and n_topics > 1:
        np.savez_compressed(pred_file, X=np.array(item_mus), y=labels)
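The function above depends on a trained model object, so it cannot be run on its own. The sketch below reproduces the same pipeline end to end under stated assumptions: StubModel, l2_normalize_rows, and save_mean_representations_sketch are hypothetical names, not part of neural_topic_models. The stub's get_mean_doc_rep uses a softmax over a random projection purely as a placeholder for the real encoder mean, and l2_normalize_rows is a NumPy stand-in for sklearn's normalize(X, axis=1). The sketch also shows how to read the saved .npz file back with np.load using the X and y keys the original code writes.

```python
import os
import tempfile

import numpy as np


class StubModel:
    """Hypothetical stand-in for the project's topic model (not its real class)."""

    def __init__(self, n_classes, d_t, dv, seed=0):
        self.n_classes = n_classes
        self.d_t = d_t  # number of topics, as used by the original function
        rng = np.random.default_rng(seed)
        # Hypothetical weights mapping a vocabulary-sized vector to d_t topics.
        self.W = rng.normal(size=(dv, d_t)).astype('float32')

    def load_params(self, filename):
        # The real model restores trained parameters here; the stub does nothing.
        pass

    def get_mean_doc_rep(self, x, y):
        # Placeholder: softmax over topics instead of the real encoder mean.
        z = x @ self.W
        e = np.exp(z - z.max())
        return e / e.sum()


def l2_normalize_rows(X):
    # Row-wise L2 normalization, like sklearn.preprocessing.normalize(X, axis=1).
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # all-zero rows are left unchanged
    return X / norms


def save_mean_representations_sketch(model, model_filename, X, labels, pred_file):
    # Same steps as the original; unlike it, this sketch also returns the array.
    test_X = l2_normalize_rows(np.array(X, dtype='float32'))
    model.load_params(model_filename)
    item_mus = [np.array(model.get_mean_doc_rep(test_X[i, :], labels[i]))
                for i in range(X.shape[0])]
    if pred_file is not None and model.d_t > 1:
        np.savez_compressed(pred_file, X=np.array(item_mus), y=labels)
    return np.array(item_mus)


X = np.array([[3, 0, 4], [0, 0, 0], [1, 1, 0]], dtype='float32')
labels = np.array([0, 1, 0])
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'reps.npz')
    reps = save_mean_representations_sketch(StubModel(2, 4, 3), 'unused', X, labels, path)
    # Read the compressed archive back; keys match the savez_compressed call.
    with np.load(path) as data:
        saved_X, saved_y = data['X'], data['y']
```

The saved X array has one row of length d_t per document, and y holds the labels unchanged, so downstream code can treat the file as a (features, labels) pair.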