main.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:crnn_tf 作者: liuhu-bigeye 项目源码 文件源码
def main():
    """Entry point: build the data readers and model, then run the configured phase.

    Expects exactly two command-line arguments, which are forwarded verbatim to
    ``Config(sys.argv[1], sys.argv[2])``. The run is driven by ``config.items``:
    ``phase``, ``batch_size``, ``lr`` and ``snap_path`` are always read; ``model``
    is optional and, when present, is restored into the session via
    ``model.saver.restore`` after variables are initialized.

    Only the ``'ctc'`` phase is implemented (it calls ``train_valid``); any other
    phase is reported with a warning and nothing is run.

    Raises:
        SystemExit: if the number of command-line arguments is not exactly two.
    """
    # Validate argv explicitly. The previous ``assert False`` is removed under
    # ``python -O``, which left ``config`` unbound and failed later with NameError.
    if len(sys.argv) != 3:
        raise SystemExit('usage: main.py <config-arg-1> <config-arg-2>')
    config = Config(sys.argv[1], sys.argv[2])

    phase = config.items['phase']
    from reader import Reader
    # Only the training reader is shuffled; validation/test use a fixed batch size of 10.
    train_set = Reader(phase='train', batch_size=config.items['batch_size'], do_shuffle=True)
    valid_set = Reader(phase='val', batch_size=10, do_shuffle=False)
    test_set = Reader(phase='test', batch_size=10, do_shuffle=False)

    glog.info('generating model...')
    from model import Model
    model = Model(config.items['lr'])

    # NOTE(review): the starting step is always 0. An earlier, disabled version
    # tried to parse it from the trailing '_<n>' of config.items['model'].
    config.items['starting'] = 0

    # Make sure the snapshot directory exists before training writes to it.
    mkdir_safe(config.items['snap_path'])

    sess_config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 4})
    sess_config.gpu_options.allow_growth = True

    with tf.Session(config=sess_config) as sess:
        tf.global_variables_initializer().run()
        if 'model' in config.items:
            # Log before restoring so that a failing restore is attributable.
            glog.info('loading model: %s...' % config.items['model'])
            model.saver.restore(sess, config.items['model'])
        if phase == 'ctc':
            glog.info('ctc training...')
            train_valid(sess, model, train_set, valid_set, test_set, config)
        else:
            # NOTE: 'extract_feature', 'get_prediction' and 'top_k_prediction'
            # were sketched in an earlier version but are currently disabled.
            glog.warning('unsupported phase: %s; nothing to run' % phase)

    glog.info('end')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号