main.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:Conditional-GAN 作者: m516825 项目源码 文件源码
def _load_prepro_pickle(name):
    """Unpickle the file ``name`` from ``FLAGS.prepro_dir`` and close it afterwards.

    NOTE(review): ``cPickle.load`` can execute arbitrary code while loading;
    only point ``FLAGS.prepro_dir`` at files produced by this project.
    """
    with open(os.path.join(FLAGS.prepro_dir, name), 'rb') as f:
        return cPickle.load(f)


def main(_):
    """Load (or preprocess) the training data, build the model named by FLAGS.model, and train it.

    The single positional argument is ignored; presumably it is the argv list
    handed over by the flag-parsing entry point (e.g. ``tf.app.run``) -- TODO confirm.
    """
    print("Parameters: ")
    # NOTE(review): ``__flags`` is a private attribute of the flags object; in
    # some TensorFlow versions its values are Flag objects rather than plain values.
    for k, v in FLAGS.__flags.items():
        print("{} = {}".format(k, v))

    # NOTE(review): this creates the hard-coded "./prepro/" while every other path
    # below uses FLAGS.prepro_dir -- confirm the two are meant to coincide.
    if not os.path.exists("./prepro/"):
        os.makedirs("./prepro/")

    if FLAGS.prepro:
        # Build the features/tag indices from the raw training data.
        img_feat, tags_idx, a_tags_idx, vocab_processor = data_utils.load_train_data(
            FLAGS.train_dir, FLAGS.tag_path, FLAGS.prepro_dir, FLAGS.vocab)
    else:
        # Reuse previously pickled preprocessing results; each file is closed after loading.
        img_feat = _load_prepro_pickle("img_feat.dat")
        tags_idx = _load_prepro_pickle("tag_ids.dat")
        a_tags_idx = _load_prepro_pickle("a_tag_ids.dat")
        vocab_processor = VocabularyProcessor.restore(FLAGS.vocab)

    # Rescale pixel values into [-1, 1] (assumes inputs lie in [0, 255] -- TODO confirm).
    img_feat = np.array(img_feat, dtype='float32') / 127.5 - 1.
    test_tags_idx = data_utils.load_test(FLAGS.test_path, vocab_processor)

    print("Image feature shape: {}".format(img_feat.shape))
    print("Tags index shape: {}".format(tags_idx.shape))
    print("Attribute Tags index shape: {}".format(a_tags_idx.shape))
    # NOTE(review): ``_reverse_mapping`` is a private attribute of the vocab processor.
    print("Vocab size: {}".format(len(vocab_processor._reverse_mapping)))
    print("Vocab max length: {}".format(vocab_processor.max_document_length))

    data = Data(img_feat, tags_idx, a_tags_idx, test_tags_idx, FLAGS.z_dim, vocab_processor)

    # Resolve the model class by name from this module's globals.
    Model = getattr(sys.modules[__name__], FLAGS.model)
    print(Model)

    model = Model(data, vocab_processor, FLAGS)
    model.build_model()
    model.train()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号