def _load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    The file is opened in binary mode inside a ``with`` block, so the handle
    is closed even if unpickling raises.
    """
    # NOTE(review): unpickling can execute arbitrary code. These files are
    # presumably produced locally by the preprocessing step; do not point
    # FLAGS.prepro_dir at untrusted data.
    with open(path, 'rb') as f:
        return cPickle.load(f)


def main(_):
    """Prepare the training data, build the model named by FLAGS.model and train it.

    Steps:
      1. Print every flag.
      2. Make sure ``./prepro/`` exists.
      3. If ``FLAGS.prepro`` is set, build features, tag indices and the
         vocabulary with ``data_utils.load_train_data``. Otherwise, load the
         pickled features and tag indices from ``FLAGS.prepro_dir`` and
         restore the vocabulary from ``FLAGS.vocab``.
      4. Rescale the image features, load the test tags and wrap everything
         in ``Data``.
      5. Instantiate the model class named by ``FLAGS.model``, build it and
         train it.

    Args:
        _: Unused positional argument (presumably the argv list supplied by a
            ``tf.app.run``-style launcher; confirm against the caller).

    Raises:
        AttributeError: if ``FLAGS.model`` does not name an attribute of this
            module.
    """
    print("Parameters: ")
    for k, v in FLAGS.__flags.items():
        print("{} = {}".format(k, v))

    # NOTE(review): this hard-codes "./prepro/", while the loading code below
    # reads from FLAGS.prepro_dir. Confirm the two are meant to be the same path.
    if not os.path.exists("./prepro/"):
        os.makedirs("./prepro/")

    if FLAGS.prepro:
        # Build features, tag indices and vocabulary from the raw training data.
        img_feat, tags_idx, a_tags_idx, vocab_processor = data_utils.load_train_data(
            FLAGS.train_dir, FLAGS.tag_path, FLAGS.prepro_dir, FLAGS.vocab)
    else:
        # Reuse artifacts stored in FLAGS.prepro_dir (presumably written by an
        # earlier run with FLAGS.prepro enabled).
        img_feat = _load_pickle(os.path.join(FLAGS.prepro_dir, "img_feat.dat"))
        tags_idx = _load_pickle(os.path.join(FLAGS.prepro_dir, "tag_ids.dat"))
        a_tags_idx = _load_pickle(os.path.join(FLAGS.prepro_dir, "a_tag_ids.dat"))
        vocab_processor = VocabularyProcessor.restore(FLAGS.vocab)

    # x / 127.5 - 1 maps [0, 255] onto [-1, 1]. This assumes raw 8-bit pixel
    # values; confirm against data_utils.
    img_feat = np.array(img_feat, dtype='float32') / 127.5 - 1.
    test_tags_idx = data_utils.load_test(FLAGS.test_path, vocab_processor)

    print("Image feature shape: {}".format(img_feat.shape))
    print("Tags index shape: {}".format(tags_idx.shape))
    print("Attribute Tags index shape: {}".format(a_tags_idx.shape))
    print("Vocab size: {}".format(len(vocab_processor._reverse_mapping)))
    print("Vocab max length: {}".format(vocab_processor.max_document_length))

    data = Data(img_feat, tags_idx, a_tags_idx, test_tags_idx, FLAGS.z_dim, vocab_processor)

    # Resolve the model class by name from this module's attributes.
    Model = getattr(sys.modules[__name__], FLAGS.model)
    print(Model)

    model = Model(data, vocab_processor, FLAGS)
    model.build_model()
    model.train()
评论列表
文章目录