import os
import sys
import numpy as np
import cPickle  # Python 2; on Python 3: import pickle as cPickle
# Assumed project-local imports: data_utils, Data, VocabularyProcessor, FLAGS.

def main(_):
    print("Parameters: ")
    for k, v in FLAGS.__flags.items():  # newer TF: FLAGS.flag_values_dict()
        print("{} = {}".format(k, v))
    if not os.path.exists("./prepro/"):
        os.makedirs("./prepro/")
    if FLAGS.prepro:
        img_feat, tags_idx, a_tags_idx, vocab_processor = data_utils.load_train_data(
            FLAGS.train_dir, FLAGS.tag_path, FLAGS.prepro_dir, FLAGS.vocab)
    else:
        # Reuse cached preprocessing results instead of re-parsing the raw data.
        img_feat = cPickle.load(open(os.path.join(FLAGS.prepro_dir, "img_feat.dat"), 'rb'))
        tags_idx = cPickle.load(open(os.path.join(FLAGS.prepro_dir, "tag_ids.dat"), 'rb'))
        a_tags_idx = cPickle.load(open(os.path.join(FLAGS.prepro_dir, "a_tag_ids.dat"), 'rb'))
        vocab_processor = VocabularyProcessor.restore(FLAGS.vocab)

    # Scale pixel values from [0, 255] to [-1, 1].
    img_feat = np.array(img_feat, dtype='float32') / 127.5 - 1.
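    # Check on the scaling above: 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0; a tanh
    # generator output maps back to pixels via (x + 1.) * 127.5 (usage assumed).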
    test_tags_idx = data_utils.load_test(FLAGS.test_path, vocab_processor)

    print("Image feature shape: {}".format(img_feat.shape))
    print("Tags index shape: {}".format(tags_idx.shape))
    print("Attribute Tags index shape: {}".format(a_tags_idx.shape))
    print("Vocab size: {}".format(len(vocab_processor._reverse_mapping)))
    print("Vocab max length: {}".format(vocab_processor.max_document_length))

    data = Data(img_feat, tags_idx, a_tags_idx, test_tags_idx, FLAGS.z_dim, vocab_processor)
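    # Assumption: Data is the project's batching container; z_dim is the size
    # of the noise vector sampled per generated image.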
    # Look the model class up by name so the --model flag selects the
    # architecture at run time.
    Model = getattr(sys.modules[__name__], FLAGS.model)
    print(Model)
    model = Model(data, vocab_processor, FLAGS)
    model.build_model()
    model.train()
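The getattr(sys.modules[__name__], FLAGS.model) lookup above resolves a class from its string name, letting a --model flag pick the architecture without an if/elif chain. A minimal standalone sketch of the same pattern (DCGAN and WGAN are placeholder class names, not the project's):

import sys

class DCGAN(object):
    def __init__(self, z_dim):
        self.z_dim = z_dim

class WGAN(object):
    def __init__(self, z_dim):
        self.z_dim = z_dim

model_name = "WGAN"  # e.g. the value of a --model flag
ModelCls = getattr(sys.modules[__name__], model_name)
model = ModelCls(z_dim=100)
print(type(model).__name__)  # -> WGAN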
Example source code for the Python Data() class
import os
import cPickle  # Python 2; on Python 3: import pickle as cPickle
# Assumed project-local imports: data_utils, Data, CapGenModel,
# VocabularyProcessor, FLAGS.

def main(_):
    print("\nParameters: ")
    for k, v in sorted(FLAGS.__flags.items()):
        print("{} = {}".format(k, v))
    if not os.path.exists("./prepro/"):
        os.makedirs("./prepro/")

    if FLAGS.eval:
        print("Evaluation...")
    else:
        if FLAGS.prepro:
            print("Start preprocessing data...")
            vocab_processor, train_dict = data_utils.load_text_data(
                train_lab=FLAGS.train_lab, prepro_train_p=FLAGS.prepro_train,
                vocab_path=FLAGS.vocab)
            print("Vocabulary size: {}".format(len(vocab_processor._reverse_mapping)))
            print("Start dumping word2vec matrix...")
            w2v_W = data_utils.build_w2v_matrix(vocab_processor, FLAGS.w2v_data,
                                                FLAGS.vector_file, FLAGS.embedding_dim)
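            # Assumption: w2v_W is a (vocab_size, embedding_dim) float matrix
            # with one row of pretrained word2vec weights per vocabulary id.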
        else:
            # Load the cached dictionaries and matrix from an earlier preprocessing run.
            train_dict = cPickle.load(open(FLAGS.prepro_train, 'rb'))
            vocab_processor = VocabularyProcessor.restore(FLAGS.vocab)
            w2v_W = cPickle.load(open(FLAGS.w2v_data, 'rb'))
print("Start generating training data...")
feats, encoder_in_idx, decoder_in = data_utils.gen_train_data(FLAGS.train_dir, FLAGS.train_lab, train_dict)
print("Start generating validation data...")
v_encoder_in, truth_captions = data_utils.load_valid(FLAGS.valid_dir, FLAGS.valid_lab)
t_encoder_in = None
files = None
if FLAGS.task_dir != None:
t_encoder_in, files = data_utils.load_task(FLAGS.task_dir)
print('feats size: {}, training size: {}'.format(len(feats), len(encoder_in_idx)))
print(encoder_in_idx.shape, decoder_in.shape)
print(v_encoder_in.shape, len(truth_captions))
        data = Data(feats, encoder_in_idx, decoder_in, v_encoder_in, truth_captions, t_encoder_in, files)
        model = CapGenModel(data, w2v_W, vocab_processor)
        model.build_model()
        model.train()
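build_w2v_matrix above is project code; as an illustration only, a common way to assemble such a matrix from a word2vec/GloVe-style text file of "word v1 v2 ..." lines looks like the sketch below. The vocab dict and file format are assumptions, not the project's actual interface:

import numpy as np

def build_embedding_matrix(vocab, vector_file, embedding_dim):
    # vocab maps word -> integer id; rows for words missing from the
    # vector file keep their small random initialization.
    W = np.random.uniform(-0.25, 0.25,
                          (len(vocab), embedding_dim)).astype('float32')
    with open(vector_file) as f:
        for line in f:
            parts = line.rstrip().split(' ')
            word, vec = parts[0], parts[1:]
            if word in vocab and len(vec) == embedding_dim:
                W[vocab[word]] = np.asarray(vec, dtype='float32')
    return W

# Usage with a toy vocabulary:
# vocab = {'<unk>': 0, 'cat': 1, 'dog': 2}
# W = build_embedding_matrix(vocab, 'vectors.txt', 50)  # W.shape == (3, 50)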