lstm.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:tensorflow-statereader 作者: sebastianGehrmann 项目源码 文件源码
def main(_):
    """Train (or restore) the PTB LSTM model, then dump its states to HDF5.

    Builds three ``PTBModel`` instances that share variables under the
    ``"Model"`` variable scope:

    * ``m``       -- training model over the training data;
    * ``mstates`` -- non-training model over the training data, used only to
      collect states;
    * ``mvalid``  -- non-training model over the validation data.

    If ``FLAGS.load_path`` is set, the latest checkpoint in it is restored.
    Otherwise the model is trained for ``FLAGS.max_max_epoch`` epochs, with
    the learning rate decayed by ``FLAGS.lr_decay`` after
    ``FLAGS.max_epoch`` epochs.

    Either way, one evaluation-only pass over the training data is then run
    with ``mstates``. The collected states are reshaped to
    ``(-1, mstates.size)`` and written to dataset ``"states1"`` in
    ``states.h5``. Finally, a checkpoint is saved if ``FLAGS.save_path`` is
    set.

    Args:
        _: Unused (presumably the argv list passed in by ``tf.app.run`` --
            confirm against the caller).

    Raises:
        ValueError: If ``--data_path`` is not set, or if ``--load_path`` is
            set but contains no checkpoint.
    """
    if not FLAGS.data_path:
        raise ValueError("Must set --data_path to PTB data directory")

    raw_data = reader.ptb_raw_data(FLAGS.data_path, True)
    # NOTE(review): assumes ptb_raw_data(..., True) returns a 3-tuple whose
    # third element is not needed here -- confirm against `reader`.
    train_data, valid_data, _ = raw_data

    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-FLAGS.init_scale,
                                                    FLAGS.init_scale)

        with tf.name_scope("Train"):
            train_input = PTBInput(data=train_data, name="TrainInput")
            # reuse=None: this model creates the shared "Model" variables.
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                m = PTBModel(is_training=True, input_=train_input)
            tf.summary.scalar("Training Loss", m.cost)
            tf.summary.scalar("Learning Rate", m.lr)

        with tf.name_scope("Train_states"):
            # Separate input pipeline over the training data, read by the
            # state-collecting model only.
            train_input = PTBInput(data=train_data, name="TrainInput")
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                mstates = PTBModel(is_training=False, input_=train_input)
            tf.summary.scalar("Training Loss", mstates.cost)

        with tf.name_scope("Valid"):
            valid_input = PTBInput(data=valid_data, name="ValidInput")
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                mvalid = PTBModel(is_training=False, input_=valid_input)
            tf.summary.scalar("Validation Loss", mvalid.cost)

        sv = tf.train.Supervisor(logdir=FLAGS.save_path)
        with sv.managed_session() as session:
            if FLAGS.load_path:
                checkpoint = tf.train.latest_checkpoint(FLAGS.load_path)
                if checkpoint is None:
                    # latest_checkpoint returns None for a directory with no
                    # checkpoint; fail clearly instead of inside restore().
                    raise ValueError(
                        "No checkpoint found in --load_path %s" % FLAGS.load_path)
                sv.saver.restore(session, checkpoint)
            else:
                for i in range(FLAGS.max_max_epoch):
                    # Constant learning rate for the first max_epoch epochs,
                    # then geometric decay by lr_decay per epoch.
                    lr_decay = FLAGS.lr_decay ** max(i + 1 - FLAGS.max_epoch, 0.0)
                    m.assign_lr(session, FLAGS.learning_rate * lr_decay)

                    print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
                    train_perplexity, stat = run_epoch(session, m, eval_op=m.train_op,
                                                       verbose=True)
                    print(stat.shape)
                    print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
                    valid_perplexity, stat = run_epoch(session, mvalid)
                    print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))

            # Run and store the states on the training set. This pass is
            # evaluation-only: no train op is passed. Passing m.train_op (the
            # training model's op, fed by its own input pipeline) would update
            # the shared weights on every step while the states are read.
            _, stat = run_epoch(session, mstates, verbose=True)
            stat = np.reshape(stat, (-1, mstates.size))
            # Context manager guarantees the HDF5 file is closed even if the
            # write fails.
            with h5py.File("states.h5", "w") as f:
                f["states1"] = stat

            if FLAGS.save_path:
                print("Saving model to %s." % FLAGS.save_path)
                sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号