def main(_):
    """Entry point: prepare (or reload) preprocessed data, then train the
    caption-generation model.

    Steps:
      * Print every flag value for reproducibility.
      * Ensure the ./prepro/ cache directory exists.
      * If --eval is set, only announce evaluation (no work done here).
      * Otherwise build or reload the vocabulary + word2vec matrix,
        generate training/validation (and optional task) data, then
        build and train CapGenModel.

    NOTE(review): original indentation was lost in this chunk; the
    structure below (training path under the non-eval branch) is the
    reconstruction consistent with the control flow — confirm against VCS.
    """
    print("\nParameters: ")
    for k, v in sorted(FLAGS.__flags.items()):
        print("{} = {}".format(k, v))

    # Cache directory used by the preprocessing helpers.
    if not os.path.exists("./prepro/"):
        os.makedirs("./prepro/")

    if FLAGS.eval:
        print("Evaluation...")
    else:
        if FLAGS.prepro:
            print("Start preprocessing data...")
            vocab_processor, train_dict = data_utils.load_text_data(
                train_lab=FLAGS.train_lab,
                prepro_train_p=FLAGS.prepro_train,
                vocab_path=FLAGS.vocab)
            print("Vocabulary size: {}".format(
                len(vocab_processor._reverse_mapping)))
            print("Start dumping word2vec matrix...")
            w2v_W = data_utils.build_w2v_matrix(
                vocab_processor, FLAGS.w2v_data,
                FLAGS.vector_file, FLAGS.embedding_dim)
        else:
            # Reload cached artifacts. Use context managers so the file
            # handles are closed deterministically (the original passed
            # open(...) straight into cPickle.load and leaked them).
            with open(FLAGS.prepro_train, 'rb') as f:
                train_dict = cPickle.load(f)
            vocab_processor = VocabularyProcessor.restore(FLAGS.vocab)
            with open(FLAGS.w2v_data, 'rb') as f:
                w2v_W = cPickle.load(f)

        print("Start generating training data...")
        feats, encoder_in_idx, decoder_in = data_utils.gen_train_data(
            FLAGS.train_dir, FLAGS.train_lab, train_dict)
        print("Start generating validation data...")
        v_encoder_in, truth_captions = data_utils.load_valid(
            FLAGS.valid_dir, FLAGS.valid_lab)

        t_encoder_in = None
        files = None
        # Identity comparison for None (PEP 8), not `!= None`.
        if FLAGS.task_dir is not None:
            t_encoder_in, files = data_utils.load_task(FLAGS.task_dir)

        print('feats size: {}, training size: {}'.format(
            len(feats), len(encoder_in_idx)))
        print(encoder_in_idx.shape, decoder_in.shape)
        print(v_encoder_in.shape, len(truth_captions))

        data = Data(feats, encoder_in_idx, decoder_in,
                    v_encoder_in, truth_captions, t_encoder_in, files)

        model = CapGenModel(data, w2v_W, vocab_processor)
        model.build_model()
        model.train()
# NOTE(review): the two lines below were stray blog-page residue
# ("评论列表" = comment list, "文章目录" = table of contents) pasted into the
# source; as bare identifiers they would raise NameError on import, so they
# are commented out.
# 评论列表
# 文章目录