main.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:crnn_tf 作者: liuhu-bigeye 项目源码 文件源码
def main():
    """Entry point: train the model, or dump predictions, driven by a Config.

    Expects exactly two command-line arguments, which are passed straight
    through to ``Config``. The flow is:

    1. build a ``Model`` with learning rate ``config.items['lr']``;
    2. optionally restore weights, preferring ``config.items['model']`` over
       ``config.items['model_old']`` (the latter via ``load_old_model``);
    3. create train/dev/test ``Reader``s; dev and test reuse the training
       set's ``feature_mean`` / ``feature_std``;
    4. record the starting epoch in ``config.items['starting']``;
    5. if a ``'predict'`` key is present, run ``prob_predict`` on all three
       splits and return; otherwise create the snapshot directories and run
       ``train_valid``.

    Raises:
        SystemExit: if the number of command-line arguments is not two.
    """
    # Validate explicitly: an ``assert`` is stripped under ``python -O``,
    # which would fall through to a NameError on the undefined ``config``.
    if len(sys.argv) != 3:
        sys.exit('expected 2 command-line arguments (passed to Config), got %d'
                 % (len(sys.argv) - 1))
    config = Config(sys.argv[1], sys.argv[2])

    # Project-local imports are kept deferred, as in the original.
    from utils import mkdir_safe, log_self
    log_self(__file__)

    glog.info('generating model...')
    from model_after import Model
    model = Model(learning_rate=config.items['lr'], config=config)

    # Restore weights if requested; 'model' takes precedence over 'model_old'.
    if 'model' in config.items:
        glog.info('loading model: %s...' % config.items['model'])
        model.load_model(config.items['model'])
    elif 'model_old' in config.items:
        glog.info('loading model from old: %s...' % config.items['model_old'])
        model.load_old_model(config.items['model_old'])

    from reader import Reader
    train_set = Reader(phase='train', config=config, do_shuffle=True, resample=True)
    # Dev/test are normalized with the *training* statistics.
    valid_set = Reader(phase='dev', config=config, do_shuffle=True, resample=False,
                       feature_mean=train_set.feature_mean, feature_std=train_set.feature_std)
    test_set = Reader(phase='test', config=config, do_shuffle=True, resample=False,
                      feature_mean=train_set.feature_mean, feature_std=train_set.feature_std)

    # A checkpoint path ending in "_<int>" resumes the epoch counter from that
    # number. Otherwise the counter starts at 0: no 'model' key (KeyError),
    # a non-string value (AttributeError), or a non-numeric suffix (ValueError).
    try:
        config.items['starting'] = int(config.items['model'].split('_')[-1])
    except (KeyError, AttributeError, ValueError):
        config.items['starting'] = 0

    # Prediction mode is triggered by the key's presence, not its value.
    if 'predict' in config.items:
        for dataset in (train_set, valid_set, test_set):
            prob_predict(model, dataset, config, epoch=config.items['starting'])
        return

    # Snapshot directories for checkpoints and per-split outputs.
    mkdir_safe(config.items['snap_path'])
    mkdir_safe(os.path.join(config.items['snap_path'], 'output_dev'))
    mkdir_safe(os.path.join(config.items['snap_path'], 'output_test'))

    glog.info('training...')
    train_valid(model, train_set, valid_set, test_set, config)

    glog.info('end')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号