main.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:miccai17-mmwhs-hybrid 作者: xy0806 项目源码 文件源码
def main(_):
    # load training parameter #
    ini_file = '../outcome/model/ini/tr_param.ini'
    param_sets = load_train_ini(ini_file)
    param_set = param_sets[0]

    print '====== Phase >>> %s <<< ======' % param_set['phase']

    if not os.path.exists(param_set['chkpoint_dir']):
        os.makedirs(param_set['chkpoint_dir'])
    if not os.path.exists(param_set['labeling_dir']):
        os.makedirs(param_set['labeling_dir'])

    # GPU setting, per_process_gpu_memory_fraction means 95% GPU MEM ,allow_growth means unfixed memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95, allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)) as sess:
        model = unet_3D_xy(sess, param_set)

        if param_set['phase'] == 'train':
            model.train()
        elif param_set['phase'] == 'test':
            # model.test()
            model.test_generate_map()
        elif param_set['phase'] == 'crsv':
            model.test4crsv()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号