train.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:lstm_gan 作者: vangaa 项目源码 文件源码
def main():
    """Train the LSTM GAN in an endless alternating discriminator/generator loop.

    Options come from ``utils.get_args()``. The dataset and the vocabulary are
    loaded from files under ``args.data_path``. Each outer pass walks the
    dataset in chunks of ``batch_size * args.disc_count`` rows:

    * the discriminator is trained on every ``batch_size`` batch of the chunk,
      each paired with fresh uniform noise from ``np.random.random``;
    * the generator is then trained for ``args.gen_count`` steps on fresh noise,
      optionally printing one generated sentence per step (``args.gen_sent``);
    * if ``args.save_model`` is set, a checkpoint is written to
      ``utils.SAVER_FILE``.

    The loop has no exit condition and never returns; stop it externally.

    If ``args.load_model`` is set and ``saver.restore`` raises ``ValueError``,
    a message is printed and the process exits with status 1.
    """
    args = utils.get_args()
    dataset = utils.load_dataset(os.path.join(args.data_path, DATASET_FILE))
    index2word, word2index = utils.load_dicts(os.path.join(args.data_path, VOCABULARY_FILE))

    # NOTE(review): assumes `dataset` is array-like with one sentence per row
    # (only `.shape[0]` is used here) — confirm against utils.load_dataset.
    print("Use dataset with {} sentences".format(dataset.shape[0]))

    batch_size = args.batch_size
    noise_size = args.noise_size
    with tf.Graph().as_default(), tf.Session() as session:
        lstm_gan = LSTMGAN(
            SENTENCE_SIZE,
            VOCABULARY_SIZE,
            word2index[SENTENCE_START_TOKEN],
            hidden_size_gen=args.hid_gen,
            hidden_size_disc=args.hid_disc,
            input_noise_size=noise_size,
            batch_size=batch_size,
            dropout=args.dropout,
            lr=args.lr,
            grad_cap=args.grad_clip,
        )

        session.run(tf.initialize_all_variables())

        # The Saver must be created after the model's variables exist; it is
        # needed both for restoring and for periodic checkpointing.
        saver = None
        if args.save_model or args.load_model:
            saver = tf.train.Saver()

        if args.load_model:
            try:
                saver.restore(session, utils.SAVER_FILE)
            except ValueError:
                print("Cant find model file")
                sys.exit(1)

        # Number of dataset rows consumed by one discriminator phase.
        chunk_size = batch_size * args.disc_count
        while True:
            # Row offset of the current chunk within the dataset, reset on
            # every full pass. It is an integer count, not a float.
            offset = 0
            for dataset_part in utils.iterate_over_dataset(dataset, chunk_size):
                print("Start train discriminator wih offset {}...".format(offset))
                for ind, batch in enumerate(utils.iterate_over_dataset(dataset_part, batch_size)):
                    noise = np.random.random(size=(batch_size, noise_size))
                    cost = lstm_gan.train_disc_on_batch(session, noise, batch)
                    print("Processed {} sentences with train cost = {}".format((ind + 1) * batch_size, cost))

                print("Start train generator...")
                for ind in range(args.gen_count):
                    noise = np.random.random(size=(batch_size, noise_size))
                    cost = lstm_gan.train_gen_on_batch(session, noise)
                    if args.gen_sent:
                        # A single noise vector (no batch dimension) yields
                        # one sentence of word indices to print.
                        sent = lstm_gan.generate_sent(session, np.random.random(size=(noise_size, )))
                        print(' '.join(index2word[i] for i in sent))
                    print("Processed {} noise inputs with train cost {}".format((ind + 1) * batch_size, cost))

                offset += chunk_size
                if args.save_model:
                    # Was `sess`, an undefined name (NameError on first save).
                    saver.save(session, utils.SAVER_FILE)
                    print("Model saved")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号