caption_gen.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Caption-Generation 作者: m516825 项目源码 文件源码
def main(_):
    """Entry point: print the configuration, prepare data, and train the model.

    The single positional argument is ignored (presumably the argv list
    handed over by the app runner -- confirm against the call site).

    Steps performed, in order:
      1. Print every flag as ``name = value``, sorted by flag name.
      2. Make sure the ``./prepro/`` directory exists.
      3. If ``FLAGS.eval`` is set, only print "Evaluation..." and return;
         no evaluation logic lives in this function.
      4. Otherwise obtain the training dictionary, vocabulary processor and
         word2vec matrix, either by preprocessing the raw data
         (``FLAGS.prepro`` set) or by loading previously dumped artifacts.
      5. Generate the training and validation data, optionally load the task
         data when ``FLAGS.task_dir`` is given, bundle everything in a
         ``Data`` object, then build and train a ``CapGenModel``.
    """
    print("\nParameters: ")
    for k, v in sorted(FLAGS.__flags.items()):
        print("{} = {}".format(k, v))

    if not os.path.exists("./prepro/"):
        os.makedirs("./prepro/")

    if FLAGS.eval:
        print("Evaluation...")
        return

    if FLAGS.prepro:
        print("Start preprocessing data...")
        vocab_processor, train_dict = data_utils.load_text_data(
            train_lab=FLAGS.train_lab,
            prepro_train_p=FLAGS.prepro_train,
            vocab_path=FLAGS.vocab)
        print("Vocabulary size: {}".format(len(vocab_processor._reverse_mapping)))

        print("Start dumping word2vec matrix...")
        w2v_W = data_utils.build_w2v_matrix(
            vocab_processor, FLAGS.w2v_data, FLAGS.vector_file, FLAGS.embedding_dim)
    else:
        # NOTE: unpickling can execute arbitrary code; these paths must only
        # ever point at files produced by a trusted preprocessing run.
        # The files are opened via ``with`` so the handles are closed
        # deterministically instead of being left to the garbage collector.
        with open(FLAGS.prepro_train, 'rb') as train_file:
            train_dict = cPickle.load(train_file)
        vocab_processor = VocabularyProcessor.restore(FLAGS.vocab)
        with open(FLAGS.w2v_data, 'rb') as w2v_file:
            w2v_W = cPickle.load(w2v_file)

    print("Start generating training data...")
    feats, encoder_in_idx, decoder_in = data_utils.gen_train_data(
        FLAGS.train_dir, FLAGS.train_lab, train_dict)
    print("Start generating validation data...")
    v_encoder_in, truth_captions = data_utils.load_valid(FLAGS.valid_dir, FLAGS.valid_lab)

    # Task data is optional; both values stay None when no task directory
    # was supplied.
    t_encoder_in = None
    files = None
    if FLAGS.task_dir is not None:
        t_encoder_in, files = data_utils.load_task(FLAGS.task_dir)

    print('feats size: {}, training size: {}'.format(len(feats), len(encoder_in_idx)))
    print(encoder_in_idx.shape, decoder_in.shape)
    print(v_encoder_in.shape, len(truth_captions))

    data = Data(feats, encoder_in_idx, decoder_in, v_encoder_in, truth_captions, t_encoder_in, files)

    model = CapGenModel(data, w2v_W, vocab_processor)

    model.build_model()

    model.train()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号