def main(batches_per_epoch=250, generate_size=200, nb_epoch=20):
    """Train the LSTM model on generated batches, then predict on the test set.

    Pipeline:
      1. Load the test features and vocabulary sizes and merge them into the
         model's input format with ``prepare_inputX``.
      2. Read the cluster centres from dataset ``'cluster'`` in ``cluster.h5``.
      3. Build the model with ``build_lstm``, compile it with the custom loss
         ``hdist`` and RMSprop, and train it with ``fit_generator`` on batches
         from ``train_generator``, validating on ``getValData()``. The model
         with the best validation loss is checkpointed to
         ``weights/<model_name>.h5``.
      4. Reload that best checkpoint, predict on the test features and hand
         the predictions to ``save_results`` together with ``result_csv_path``.

    ``model_name``, ``dis_size``, ``emb_size``, ``result_csv_path``, ``hdist``
    and the data/model helpers are expected at module scope (they are not
    defined in this function).

    Args:
        batches_per_epoch: number of generator batches per training epoch.
        generate_size: value passed to ``train_generator``; presumably the
            number of samples per generated batch (it is multiplied by
            ``batches_per_epoch`` to get ``samples_per_epoch``).
        nb_epoch: number of training epochs.
    """
    import os

    print('1. Loading data.............')
    te_con_feature, te_emb_feature, te_seq_feature, vocabs_size = load_test_dataset()
    n_con = te_con_feature.shape[1]
    n_emb = te_emb_feature.shape[1]
    print('1.1 merge con_feature,emb_feature,seq_feature.....')
    test_feature = prepare_inputX(te_con_feature, te_emb_feature, te_seq_feature)

    print('2. cluster.........')
    # `[:]` copies the whole dataset into memory, so the file can be closed
    # right away instead of leaking the handle.
    with h5py.File('cluster.h5', 'r') as cluster_file:
        cluster_centers = cluster_file['cluster'][:]

    print('3. Building model..........')
    # The last argument is the number of cluster centres (rows of the array).
    model = build_lstm(n_con, n_emb, vocabs_size, dis_size, emb_size,
                       cluster_centers.shape[0])

    weights_dir = 'weights'
    if not os.path.isdir(weights_dir):
        # Saving the checkpoint fails if the target directory is missing.
        os.makedirs(weights_dir)
    checkpoint_path = os.path.join(weights_dir, model_name + '.h5')
    checkPoint = ModelCheckpoint(checkpoint_path, save_best_only=True)
    # NOTE(review): patience is counted in epochs; with the default
    # nb_epoch=20, a patience of 500 means early stopping never fires.
    earlystopping = EarlyStopping(patience=500)
    # Alternative previously tried: loss='mse', optimizer=Adagrad.
    model.compile(loss=hdist, optimizer='rmsprop')

    tr_generator = train_generator(generate_size)
    model.fit_generator(
        tr_generator,
        samples_per_epoch=batches_per_epoch * generate_size,
        nb_epoch=nb_epoch,
        validation_data=getValData(),
        verbose=1,
        callbacks=[checkPoint, earlystopping],
    )

    print('4. Predicting result .............')
    # save_best_only=True keeps only the best model on disk, while the
    # in-memory model still holds the last epoch's weights. Reload the
    # checkpoint so predictions come from the best model.
    if os.path.exists(checkpoint_path):
        model.load_weights(checkpoint_path)
    te_predict = model.predict(test_feature)
    save_results(te_predict, result_csv_path)
评论列表
文章目录