train_mlp.py source code

python

Project: taxi  Author: xuguanggen
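import h5py
import pandas as pd
from keras.callbacks import ModelCheckpoint

# load_dataset, prepare_inputX, cluster, build_mlp, hdist, dis_size, emb_size,
# model_name and Test_CSV_Path are not shown in this excerpt; they are expected
# to be defined elsewhere in train_mlp.py or its imports.

def main(result_csv_path,hasCluster):
    # Train the MLP and write test-set predictions to result_csv_path.
    # hasCluster: if true, read cluster centers from cluster.h5 instead of recomputing them.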
    print('1. Loading Data.........')
    tr_con_feature,tr_emb_feature,tr_label,te_con_feature,te_emb_feature,vocabs_size = load_dataset()

    n_con = tr_con_feature.shape[1]
    n_emb = tr_emb_feature.shape[1]

    train_x = prepare_inputX(tr_con_feature,tr_emb_feature)
    test_x = prepare_inputX(te_con_feature,te_emb_feature)
    print('1.1 cluster.............')
    if hasCluster:
        # Load precomputed cluster centers; the context manager closes the file.
        with h5py.File('cluster.h5','r') as f:
            cluster_centers = f['cluster'][:]
    else:
        cluster_centers = cluster()

    print('2. Building model..........')
    model = build_mlp(n_con,n_emb,vocabs_size,dis_size,emb_size,cluster_centers.shape[0])
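    # Keep only the weights with the best validation loss (the default monitored metric).
    checkPoint = ModelCheckpoint('weights/' + model_name +'.h5',save_best_only=True)
    model.compile(loss=hdist,optimizer='rmsprop') #[loss = 'mse',optimizer= Adagrad]
    model.fit(
        train_x,
        tr_label,
        nb_epoch = 2000, #1000 # 1500 (Keras 1 argument name; Keras 2 renamed it to `epochs`)
        batch_size = 500, # 500 #400
        verbose = 1,
        validation_split = 0.3, # hold out the last 30% of the training data for validation
        callbacks = [checkPoint]
    )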
    ##### dump model ########
    #json_string = model.to_json()
    #open('weights/'+ model_name +'.json','w').write(json_string)
    #model.save_weights('weights/'+ model_name + '.h5',overwrite=True,)

    ####### predict #############################
    print('3. Predicting result.........')
    te_predict = model.predict(test_x)
    df_test = pd.read_csv(Test_CSV_Path,header=0)
    result = pd.DataFrame()
    result['TRIP_ID'] = df_test['TRIP_ID']
    # The model outputs (longitude, latitude): column 0 is longitude, column 1 is latitude.
    result['LATITUDE'] = te_predict[:,1]
    result['LONGITUDE'] = te_predict[:,0]
    result.to_csv(result_csv_path,index=False)
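
The loss `hdist` passed to `model.compile` is defined elsewhere in the project and is not shown here. Assuming it measures a great-circle distance between predicted and true (longitude, latitude) pairs, the idea can be sketched in plain Python as a haversine distance. The function name `haversine_km`, the radius constant and the argument order are illustrative assumptions, not the project's code, and a real Keras loss would have to be written with backend tensor operations instead of `math`.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius; an assumed constant


def haversine_km(lon1, lat1, lon2, lat2):
    """Great-circle distance in km between two points given in degrees.

    Hypothetical helper: arguments are ordered (longitude, latitude) to
    match the column order of the model output in main().
    """
    # Convert all four angles from degrees to radians.
    lon1, lat1, lon2, lat2 = map(math.radians, (lon1, lat1, lon2, lat2))
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    # Haversine formula.
    a = (math.sin(dlat / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))
```

Such a function could be used offline to score rows of `te_predict` against known destinations, passing column 0 as longitude and column 1 as latitude.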