main.py source code

python

Project: VQG-tensorflow    Author: JamesChuanggg
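def main(_):
    # tf, question_generator, conf, pp, get_data and calc_gpu_fraction come from main.py's module scope (not shown)

    # Print every parsed flag; this reads the internal '__flags' dict of the FLAGS object
    attrs = conf.__dict__['__flags']
    pp(attrs)

    # Load the dataset metadata (JSON) plus the image features and questions (HDF5)
    dataset, img_feature, train_data = get_data(conf.input_json, conf.input_img_h5, conf.input_ques_h5, conf.img_norm)

    # Limit the share of GPU memory this process may allocate
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=calc_gpu_fraction(conf.gpu_fraction))

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        model = question_generator.Question_Generator(sess, conf, dataset, img_feature, train_data)

        if conf.is_train:
            # Training mode: build the training graph, then run the training loop
            model.build_model()
            model.train()
        else:
            # Test mode: build the generator graph and generate questions for a test image
            model.build_generator()
            model.test(test_image_path=conf.test_image_path, model_path=conf.test_model_path, maxlen=26)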
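The listing does not show `calc_gpu_fraction`. As a hypothetical sketch only, assuming the `gpu_fraction` flag holds a string such as `"1/3"` or `"0.5"`, a helper of that name might convert it into the float that `per_process_gpu_memory_fraction` takes (a share of the GPU's total memory, normally in (0, 1]). The project's real helper may parse the flag differently.

```python
from fractions import Fraction


def calc_gpu_fraction(fraction_flag):
    """Hypothetical helper: turn a flag such as "0.5" or "1/3" into a float.

    Assumption: the flag is a decimal or ratio string; the real VQG-tensorflow
    helper is not shown in the listing and may behave differently.
    """
    # Fraction parses both decimal strings ("0.5") and ratios ("1/3").
    fraction = float(Fraction(str(fraction_flag).strip()))
    # Reject values outside the usual (0, 1] range for a memory fraction.
    if not 0.0 < fraction <= 1.0:
        raise ValueError("GPU memory fraction must be in (0, 1], got %r" % (fraction_flag,))
    return fraction


if __name__ == "__main__":
    for flag in ("1/3", "0.5", "1"):
        print(flag, "->", calc_gpu_fraction(flag))
```

Using `fractions.Fraction` avoids hand-splitting on `/` and accepts plain decimals through the same code path.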