def main(_):
    """Build and run the question generator in training or test mode.

    Args:
        _: Unused. Presumably the leftover argv passed in by ``tf.app.run``
            -- TODO confirm how ``main`` is invoked.

    Flow: pretty-prints the configured flag values, loads the dataset via
    ``get_data``, and opens a TensorFlow session whose per-process GPU memory
    fraction comes from ``calc_gpu_fraction(conf.gpu_fraction)``. If
    ``conf.is_train`` is set, it builds and trains the model. Otherwise it
    builds the generator and runs ``model.test`` on ``conf.test_image_path``
    using the model at ``conf.test_model_path``.
    """
    # Newer TensorFlow flags (absl-backed) expose the public
    # flag_values_dict(). Their private '__flags' entry maps names to Flag
    # objects, not values, so prefer the public accessor. Fall back to the
    # legacy private dict on older versions that lack it.
    if hasattr(conf, 'flag_values_dict'):
        attrs = conf.flag_values_dict()
    else:
        attrs = conf.__dict__['__flags']
    pp(attrs)

    dataset, img_feature, train_data = get_data(
        conf.input_json, conf.input_img_h5, conf.input_ques_h5, conf.img_norm)

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=calc_gpu_fraction(conf.gpu_fraction))
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        model = question_generator.Question_Generator(
            sess, conf, dataset, img_feature, train_data)
        if conf.is_train:
            model.build_model()
            model.train()
        else:
            model.build_generator()
            # NOTE(review): maxlen=26 is a hard-coded limit passed to
            # model.test; its meaning (presumably max generated question
            # length) is defined elsewhere -- TODO confirm.
            model.test(test_image_path=conf.test_image_path,
                       model_path=conf.test_model_path,
                       maxlen=26)
评论列表
文章目录