lstmUtils.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:RNNIPTag 作者: ml-slac 项目源码 文件源码
def BuildModel():
    """Build, evaluate and save the model selected by the global options ``o``.

    Steps:
      1. Load the dataset with ``makeData(Variables=o.Variables)``.
      2. Build the model chosen by ``o.Model`` through the matching
         ``buildModel_*`` helper, each of which returns ``(model, history)``.
      3. Pass the model through ``evalModel`` and keep whatever it returns.
      4. Derive a descriptive output name from the options and hand it,
         together with the model and training history, to ``saveModel``.

    NOTE(review): ``o`` and ``n_events`` are module-level globals defined
    elsewhere. ``o`` looks like a parsed command-line options object whose
    settings such as ``nEpoch``/``nMaxTrack``/``nLayers`` are strings, since
    they are concatenated directly into the name. ``n_events`` is presumably
    an int event count. Confirm both against the caller.
    """
    dataset = makeData(Variables=o.Variables)

    model = None
    history = None

    print(o.Model)
    # These model-type tests are mutually exclusive, so at most one builder runs.
    if "LSTM" in o.Model or "GRU" in o.Model:
        model, history = buildModel_1hidden(dataset, True)
    elif o.Model == "RNNSV1":
        model, history = buildModel_RNNSV1(dataset, True)
    elif o.Model == "DenseIP3D":
        model, history = buildModel_SimpleDense(dataset, False)
    print(' ------------------------------------------')
    print(o.Model)
    if o.Model in ("RNNPlusMV2", "RNNPlusSV1"):
        model, history = buildModel_RNNPlus(dataset, useAdam=True)

    if model is None:
        # Previously this case passed silently; surface it so a typo in the
        # model option is noticed instead of evaluating/saving a None model.
        print("WARNING: no model was built for o.Model = %r" % (o.Model,))

    # ``//`` keeps the event count a whole number of thousands on both
    # Python 2 and Python 3 (plain ``/`` yields a float on Python 3).
    modelname = (o.Version + "_" + o.Model + "_" + o.Variables + "_"
                 + o.nEpoch + "epoch_"
                 + str(n_events // 1000) + 'kEvts_'
                 + str(o.nTrackCut) + 'nTrackCut_'
                 + o.nMaxTrack + "nMaxTrack_"
                 + o.nLSTMClass + "nLSTMClass_"
                 + o.nLSTMNodes + "nLSTMNodes_"
                 + o.nLayers + "nLayers")

    model = evalModel(dataset, model, o.Model)

    # Suffix describing how tracks were ordered; no suffix for other orders.
    track_order_suffix = {
        'pT': "_SortpT",
        'Reverse': "_ReverseOrder",
        'SL0': "_SL0",
    }
    modelname += track_order_suffix.get(o.TrackOrder, "")

    if o.doTrainC == 'y':
        modelname += "_CMix"
    if o.AddJetpT == 'y':
        modelname += '_AddJetpT'
    if int(o.EmbedSize) != 2:
        modelname += "_" + o.EmbedSize + "EmbedSize"

    # Retrain mode replaces the whole name built so far with one derived from
    # o.filebase; only the suffixes appended below survive the override.
    if o.Mode == "R":
        modelname = o.filebase + "_Retrain_" + o.nEpoch
    if o.doLessC == "y":
        modelname += "_LessC"
    if o.doJetpTReweight == "y":
        modelname += "_JetpTReweight"

    saveModel(modelname, model, history)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号