def train(args):
    """Fit a VarAutoEncoder on the corpus at ``args.input`` and print the runtime.

    The corpus is loaded with ``load_corpus``, each document is turned into a
    vector with ``doc2vec`` and normalised with ``vecnorm``, the vectors are
    shuffled with a fixed seed, and the last ``args.n_val`` of them are held
    out as the validation set.

    Args:
        args: Namespace-like object providing ``input``, ``n_val``, ``n_dim``,
            ``comp_topk``, ``ctype``, ``save_model``, ``n_epoch`` and
            ``batch_size``.

    Raises:
        ValueError: if ``args.n_val`` is negative.
    """
    n_val = args.n_val
    # Validate before the expensive load: a negative count would silently
    # produce an inverted train/validation split.
    if n_val < 0:
        raise ValueError('n_val must be non-negative, got %r' % (n_val,))

    corpus = load_corpus(args.input)
    n_vocab, docs = len(corpus['vocab']), corpus['docs']
    corpus.clear()  # save memory

    # Iterate over a snapshot of the keys and pop each document as it is
    # vectorised, so raw documents are released progressively and the mapping
    # is never mutated while it is being iterated.
    # NOTE(review): the meaning of the 'logmax1' / 0 arguments to vecnorm is
    # not visible here -- confirm against its definition.
    X_docs = [vecnorm(doc2vec(docs.pop(k), n_vocab), 'logmax1', 0)
              for k in list(docs.keys())]

    # Fixed seed keeps the shuffle (and therefore the split) reproducible.
    # This reseeds NumPy's global random state.
    np.random.seed(0)
    np.random.shuffle(X_docs)

    # Split on an explicit index: the slice `[:-n_val]` is empty when
    # n_val == 0, which would leave nothing to train on.
    n_train = max(len(X_docs) - n_val, 0)
    X_train = np.r_[X_docs[:n_train]]
    del X_docs[:n_train]  # free the training vectors once they are stacked
    X_val = np.r_[X_docs]
    del X_docs
    # NOTE(review): with n_val == 0 X_val is empty; whether fit() accepts an
    # empty validation set is not visible here -- confirm.

    start = timeit.default_timer()
    vae = VarAutoEncoder(n_vocab, args.n_dim, comp_topk=args.comp_topk,
                         ctype=args.ctype, save_model=args.save_model)
    # The same matrix is passed twice for both the training and validation
    # pairs (presumably as input and reconstruction target -- verify against
    # VarAutoEncoder.fit).
    vae.fit([X_train, X_train], [X_val, X_val],
            nb_epoch=args.n_epoch, batch_size=args.batch_size)
    print('runtime: %ss' % (timeit.default_timer() - start))
评论列表
文章目录