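# train_vae.py 文件源码 (train_vae.py source file)
#
# python
# 阅读 21 收藏 0 点赞 0 评论 0 (views 21, favorites 0, likes 0, comments 0)
#
# 项目：KATE  作者：hugochan  项目源码 文件源码 (project: KATE, author: hugochan)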
def train(args):
    """Train a VarAutoEncoder on a document corpus and print the fit runtime.

    Loads the corpus from ``args.input``, converts each document with
    ``doc2vec`` followed by ``vecnorm(..., 'logmax1', 0)``, shuffles the
    resulting vectors with a fixed NumPy seed, holds out the last
    ``args.n_val`` of them for validation, and fits the model on the rest.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``input``, ``n_val``, ``n_dim``, ``comp_topk``, ``ctype``,
            ``save_model``, ``n_epoch``, ``batch_size``.

    Raises:
        ValueError: if ``args.n_val`` does not leave at least one document
            in both the training split and the validation split.
    """
    corpus = load_corpus(args.input)
    n_vocab, docs = len(corpus['vocab']), corpus['docs']
    corpus.clear()  # save memory

    # Vectorize each document and drop the raw one from ``docs`` right away
    # to limit peak memory. Iterating over a snapshot of the keys keeps the
    # deletion safe on Python 3, where ``dict.keys()`` is a live view.
    X_docs = []
    for k in list(docs.keys()):
        X_docs.append(vecnorm(doc2vec(docs.pop(k), n_vocab), 'logmax1', 0))

    np.random.seed(0)  # fixed seed -> reproducible shuffle and split
    np.random.shuffle(X_docs)

    n_val = args.n_val
    n_docs = len(X_docs)
    # Slicing with X_docs[:-n_val] silently yields an empty training set when
    # n_val == 0 (and a meaningless split when n_val < 0), so validate first.
    if not 0 < n_val < n_docs:
        raise ValueError(
            'n_val must be between 1 and %d (number of documents - 1), got %r'
            % (n_docs - 1, n_val))
    n_train = n_docs - n_val

    X_train = np.r_[X_docs[:n_train]]
    del X_docs[:n_train]  # release training rows before building X_val
    X_val = np.r_[X_docs]
    del X_docs

    start = timeit.default_timer()

    vae = VarAutoEncoder(n_vocab, args.n_dim, comp_topk=args.comp_topk,
                         ctype=args.ctype, save_model=args.save_model)
    # The same matrix is passed twice -- presumably as input and
    # reconstruction target; confirm against VarAutoEncoder.fit.
    vae.fit([X_train, X_train], [X_val, X_val],
            nb_epoch=args.n_epoch, batch_size=args.batch_size)

    print('runtime: %ss' % (timeit.default_timer() - start))
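# 评论列表 (comment list)
# 文章目录 (article contents)


# 问题 (questions)


# 面经 (interview notes)


# 文章 (articles)

# 微信 (WeChat)
# 公众号 (official account)

# 扫码关注公众号 (scan the QR code to follow the official account)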