train_lstm.py source code

python

Project: taxi    Author: xuguanggen
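```python
import h5py
from keras.callbacks import ModelCheckpoint, EarlyStopping

# load_test_dataset, prepare_inputX, build_lstm, train_generator, getValData,
# save_results, the hdist loss, and the settings dis_size, emb_size,
# model_name and result_csv_path are defined elsewhere in the taxi project
# and are not shown in this excerpt.


def main():
    batches_per_epoch = 250  # batches drawn from the generator per epoch
    generate_size = 200      # samples per generated batch
    nb_epoch = 20
    print('1. Loading data.............')
    te_con_feature, te_emb_feature, te_seq_feature, vocabs_size = load_test_dataset()

    # Column counts of the continuous and embedding feature matrices.
    n_con = te_con_feature.shape[1]
    n_emb = te_emb_feature.shape[1]
    print('1.1 merge con_feature,emb_feature,seq_feature.....')
    test_feature = prepare_inputX(te_con_feature, te_emb_feature, te_seq_feature)

    print('2. cluster.........')
    # Read the precomputed cluster centers and close the HDF5 file afterwards.
    with h5py.File('cluster.h5', 'r') as f:
        cluster_centers = f['cluster'][:]

    print('3. Building model..........')
    # The number of cluster centers is passed to build_lstm as its last argument.
    model = build_lstm(n_con, n_emb, vocabs_size, dis_size, emb_size,
                       cluster_centers.shape[0])
    # Save weights only when the monitored metric (val_loss by default) improves.
    checkPoint = ModelCheckpoint('weights/' + model_name + '.h5', save_best_only=True)
    # With nb_epoch = 20, a patience of 500 epochs never triggers a stop.
    earlystopping = EarlyStopping(patience=500)
    model.compile(loss=hdist, optimizer='rmsprop')  # alternative: loss='mse', optimizer=Adagrad
    tr_generator = train_generator(generate_size)
    # Keras 1.x signature: epoch length is given in samples, not steps.
    model.fit_generator(
        tr_generator,
        samples_per_epoch=batches_per_epoch * generate_size,
        nb_epoch=nb_epoch,
        validation_data=getValData(),
        verbose=1,
        callbacks=[checkPoint, earlystopping]
    )

    print('4. Predicting result .............')
    # Predictions use the final-epoch weights in memory, not the best checkpoint on disk.
    te_predict = model.predict(test_feature)
    save_results(te_predict, result_csv_path)
```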
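In main(), an epoch is defined as batches_per_epoch * generate_size = 250 * 200 = 50,000 samples, which is the unit Keras 1.x fit_generator expects for samples_per_epoch. Keras 2.x replaced samples_per_epoch with steps_per_epoch (counted in batches) and nb_epoch with epochs. The sketch below shows that arithmetic and the contract fit_generator relies on: a generator that yields (inputs, targets) batches indefinitely. toy_train_generator is hypothetical, a stand-in for the project's train_generator, whose real features are not shown; it yields random numbers only to illustrate the shape of each batch.

```python
import itertools
import random
from typing import Iterator, List, Tuple

# Values taken from main() in train_lstm.py.
batches_per_epoch = 250
generate_size = 200
nb_epoch = 20

# Keras 1.x: fit_generator measures an epoch in samples.
samples_per_epoch = batches_per_epoch * generate_size

# Keras 2.x equivalent: an epoch is measured in generator steps (batches).
steps_per_epoch = samples_per_epoch // generate_size


def toy_train_generator(batch_size: int) -> Iterator[Tuple[List[List[float]], List[float]]]:
    """Hypothetical stand-in for the project's train_generator.

    Yields (inputs, targets) batches forever, which is what fit_generator
    expects; the real generator's features are not shown in the source.
    """
    rng = random.Random(0)
    while True:
        xs = [[rng.random(), rng.random()] for _ in range(batch_size)]
        ys = [rng.random() for _ in range(batch_size)]
        yield xs, ys


print(samples_per_epoch, steps_per_epoch)  # 50000 250

# One epoch's worth of batches, drawn the way fit_generator would draw them.
first_epoch = list(itertools.islice(toy_train_generator(generate_size), steps_per_epoch))
print(len(first_epoch), len(first_epoch[0][0]))  # 250 200
```

Because the generator never terminates, the epoch boundary comes entirely from samples_per_epoch (or steps_per_epoch in Keras 2.x), not from the data running out.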
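The two callbacks in main() interact with nb_epoch in a way that is easy to miss. The following is a simplified model of their rules written in plain Python, not Keras's own implementation: a best-only checkpoint keeps the weights from the epoch with the lowest validation loss, and early stopping halts once the loss has failed to improve for patience consecutive epochs. The function simulate_callbacks and the loss values are hypothetical.

```python
from typing import List, Optional, Tuple


def simulate_callbacks(val_losses: List[float], patience: int) -> Tuple[int, Optional[int]]:
    """Simplified model of ModelCheckpoint(save_best_only=True) + EarlyStopping.

    Returns (best_epoch, stop_epoch). best_epoch is the epoch whose weights a
    best-only checkpoint would hold at the end; stop_epoch is the epoch after
    which training halts, or None if patience is never exhausted.
    """
    best = float("inf")
    best_epoch = -1
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement: the checkpoint would overwrite the saved weights.
            best, best_epoch, wait = loss, epoch, 0
        else:
            # No improvement: count towards the patience limit.
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


losses = [1.0, 0.8, 0.9, 0.85, 0.95]
print(simulate_callbacks(losses, patience=2))    # (1, 3)
print(simulate_callbacks(losses, patience=500))  # (1, None)
```

With 20 epochs, at most 19 non-improving epochs can follow the first one, so patience=500 can never be reached; EarlyStopping only has an effect when patience is smaller than nb_epoch.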
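The loss passed to model.compile, hdist, is not shown in this excerpt. Its name, together with the project's taxi setting, suggests a haversine (great-circle) distance between predicted and true coordinates, but that is an assumption. A minimal sketch of the haversine formula in plain Python follows, assuming (latitude, longitude) pairs in degrees and a mean Earth radius of 6371 km. A real Keras loss would need to express the same math with backend tensor operations rather than the math module; haversine_km and mean_haversine are hypothetical names.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in km between two (lat, lon) points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


def mean_haversine(preds, targets):
    """Average distance over a batch of (lat, lon) pairs, as a loss would report it."""
    dists = [haversine_km(p[0], p[1], t[0], t[1]) for p, t in zip(preds, targets)]
    return sum(dists) / len(dists)


# One degree of latitude along a meridian is about 111.19 km with this radius.
print(round(haversine_km(0.0, 0.0, 1.0, 0.0), 2))
```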
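build_lstm receives cluster_centers.shape[0], the number of cluster centers read from cluster.h5, which suggests the network's output layer is sized by the number of clusters. One way such clusters can be used, and purely an assumption here since build_lstm is not shown, is to apply a softmax over per-cluster scores and predict the probability-weighted mean of the centers. The helpers softmax and weighted_destination below are hypothetical and written in plain Python for illustration.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn per-cluster scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_destination(logits: Sequence[float],
                         centers: Sequence[Tuple[float, float]]) -> Tuple[float, float]:
    """Hypothetical output head: probability-weighted mean of the cluster centers."""
    probs = softmax(logits)
    lat = sum(p * c[0] for p, c in zip(probs, centers))
    lon = sum(p * c[1] for p, c in zip(probs, centers))
    return lat, lon


centers = [(0.0, 0.0), (2.0, 4.0)]
print(weighted_destination([0.0, 0.0], centers))  # (1.0, 2.0): equal weights give the midpoint
```

Under this assumption the prediction always lies inside the convex hull of the cluster centers, which is why the number of clusters fixes the size of the output layer.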