n11_train.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:kaggle_yt8m 作者: N01Z3 项目源码 文件源码
def get_mod(ags):
    """Build the model selected by ``ags.mtype`` and optionally resume weights.

    Args:
        ags: parsed command-line arguments. Reads ``wpath`` and ``versn``
            (weights directory), ``optim`` and ``lrate`` (optimizer choice
            and learning rate), ``mtype`` (architecture index 0-6) and
            ``begin`` (``-1`` means resume from saved weights).

    Returns:
        ``(model, b_scr)``. ``b_scr`` is the number parsed from the prefix
        (text before the first ``_``) of the loaded weight file's name, or
        ``-1`` when no weights were loaded.

    Raises:
        ValueError: if ``ags.mtype`` is not a known architecture index.
    """
    dst = os.path.join(ags.wpath, ags.versn)
    b_scr = -1

    if ags.optim == 'adam':
        opt = adam(ags.lrate)
    elif ags.optim == 'sgd':
        opt = sgd(ags.lrate)
    else:
        # Unknown optimizer name: fall back to adam with its default
        # settings. ags.lrate is NOT applied here, so make that visible.
        logging.warning('unknown optimizer %r, falling back to adam()', ags.optim)
        opt = adam()

    # mtype index -> (builder, model number shown in the log message).
    # Only the selected builder is invoked. Previously all seven
    # architectures were constructed eagerly and then discarded.
    builders = {
        0: (build_mod2, 2),
        1: (build_mod3, 3),
        2: (build_mod7, 7),
        3: (build_mod9, 9),
        4: (build_mod11, 11),
        5: (build_mod12, 12),
        6: (build_mod13, 13),
    }
    if ags.mtype not in builders:
        raise ValueError('unknown model type: %r (expected one of %s)'
                         % (ags.mtype, sorted(builders)))
    build, num = builders[ags.mtype]
    model = build(opt)
    logging.info('start with model %d', num)

    if ags.begin == -1:
        # Resume from the last matching file in lexicographic order.
        # NOTE(review): assumes checkpoint names start with a score prefix
        # ("<score>_...") whose string order matches numeric order. Confirm
        # against the code that saves the weights.
        fls = sorted(glob.glob(dst + '/*h5'))
        if fls:
            last = fls[-1]
            logging.info('load weights: %s', last)
            model.load_weights(last)
            b_scr = float(os.path.basename(last).split('_')[0])

    return model, b_scr
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号