def _train_selected_model(dataset):
    """Build and train the architecture selected by ``o.Model``.

    Returns the ``(model, history)`` pair produced by the matching
    ``buildModel_*`` helper, or ``(None, None)`` when ``o.Model`` matches
    none of the known architectures.
    """
    name = o.Model
    # Substring match: any model name mentioning LSTM or GRU goes to the
    # single-hidden-layer recurrent builder. The remaining checks are exact
    # names, so at most one branch can apply.
    if "LSTM" in name or "GRU" in name:
        return buildModel_1hidden(dataset, True)
    if name == "RNNSV1":
        return buildModel_RNNSV1(dataset, True)
    if name == "DenseIP3D":
        return buildModel_SimpleDense(dataset, False)
    if name in ("RNNPlusMV2", "RNNPlusSV1"):
        return buildModel_RNNPlus(dataset, useAdam=True)
    # NOTE(review): an unrecognised o.Model yields no model; the caller then
    # hands None to evalModel, which presumably fails -- confirm intent.
    return None, None


def _base_model_name():
    """Return the model name encoding version, architecture and training setup.

    Assumes the ``o.*`` fields used with ``+`` are strings (as produced by an
    option parser) and that ``n_events`` is an int -- TODO confirm.
    """
    return (
        o.Version + "_" + o.Model + "_" + o.Variables + "_"
        + o.nEpoch + "epoch_"
        # '//' keeps the Python 2 integer-division result on Python 3 too.
        + str(n_events // 1000) + "kEvts_"
        + str(o.nTrackCut) + "nTrackCut_"
        + o.nMaxTrack + "nMaxTrack_"
        + o.nLSTMClass + "nLSTMClass_"
        + o.nLSTMNodes + "nLSTMNodes_"
        + o.nLayers + "nLayers"
    )


def _decorate_model_name(modelname):
    """Append option-dependent tags to ``modelname`` and return the result.

    In retrain mode (``o.Mode == "R"``) the name built so far is discarded and
    replaced by ``o.filebase + "_Retrain_" + o.nEpoch``. The LessC and
    JetpTReweight tags are appended after that replacement.
    """
    track_order_tags = {
        "pT": "_SortpT",
        "Reverse": "_ReverseOrder",
        "SL0": "_SL0",
    }
    modelname += track_order_tags.get(o.TrackOrder, "")
    if o.doTrainC == "y":
        modelname += "_CMix"
    if o.AddJetpT == "y":
        modelname += "_AddJetpT"
    # EmbedSize arrives as a string; only a non-default size is tagged.
    if int(o.EmbedSize) != 2:
        modelname += "_" + o.EmbedSize + "EmbedSize"
    if o.Mode == "R":
        # Retraining: rebuild the name from the input file base, dropping
        # every tag accumulated above.
        modelname = o.filebase + "_Retrain_" + o.nEpoch
    if o.doLessC == "y":
        modelname += "_LessC"
    if o.doJetpTReweight == "y":
        modelname += "_JetpTReweight"
    return modelname


def BuildModel():
    """Build, train, evaluate and save the model selected by the options ``o``.

    Steps:
      1. Load the dataset via ``makeData(Variables=o.Variables)``.
      2. Train the architecture chosen by ``o.Model``.
      3. Build the base model name.
      4. Run ``evalModel``; its return value replaces the model.
      5. Append option tags to the name.
      6. Save model and history with ``saveModel``.

    Reads the module-level ``o`` options object and ``n_events``; returns None.
    """
    dataset = makeData(Variables=o.Variables)
    print(o.Model)
    model, history = _train_selected_model(dataset)
    # Base name is formed before evaluation, tags after, as in the original flow.
    modelname = _base_model_name()
    model = evalModel(dataset, model, o.Model)
    modelname = _decorate_model_name(modelname)
    saveModel(modelname, model, history)
评论列表
文章目录