run_train.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目: almond-nnparser 作者: Stanford-Mobisocial-IoT-Lab 项目源码 文件源码
def run():
    """Train a model from the command line.

    Usage::

        python3 <script> <<Model Directory>> <<Train Set>> [<<Dev Set>>]

    Loads the configuration from ``./default.conf`` together with
    ``<model_dir>/model.conf``, creates the model, and loads the training set
    (and the optional dev set). It then writes ``model.conf`` into the model
    directory if that file does not exist yet. Finally it builds the model in a
    fresh TensorFlow graph, fits it with ``Trainer`` and prints the best train
    and dev results that ``Trainer.fit`` returns.

    Exits with status 1 (after printing a usage line) when fewer than two
    positional arguments are supplied.
    """
    if len(sys.argv) < 3:
        print("** Usage: python3 " + sys.argv[0] + " <<Model Directory>> <<Train Set>> [<<Dev Set>>]")
        sys.exit(1)

    # Fix NumPy's global RNG seed so repeated runs are reproducible.
    np.random.seed(42)

    model_dir = sys.argv[1]
    model_conf = os.path.join(model_dir, 'model.conf')
    # NOTE(review): presumably later files override earlier ones and a missing
    # model.conf is tolerated on a first run -- confirm against Config.load.
    config = Config.load(['./default.conf', model_conf])
    model = create_model(config)
    train_data = load_data(sys.argv[2], config.dictionary, config.grammar, config.max_length)
    if len(sys.argv) > 3:
        dev_data = load_data(sys.argv[3], config.dictionary, config.grammar, config.max_length)
    else:
        dev_data = None
    # NOTE(review): unknown_tokens is not defined in this function; presumably a
    # module-level collection populated by load_data -- confirm.
    print("unknown", unknown_tokens)

    # Create the model directory, including any missing parents. exist_ok=True
    # tolerates only an already-existing *directory*. Other failures (no
    # permission, a regular file at that path) are raised here. They are not
    # swallowed, so they cannot resurface later as a confusing error when
    # writing model.conf or checkpoints into the directory.
    os.makedirs(model_dir, exist_ok=True)
    # Persist the effective configuration only on the first run, so an
    # existing model.conf is never overwritten.
    if not os.path.exists(model_conf):
        config.save(model_conf)

    with tf.Graph().as_default():
        # Graph-level seed for TensorFlow's random ops.
        tf.set_random_seed(1234)
        model.build()
        init = tf.global_variables_initializer()

        # Keep at most n_epochs recent checkpoints (presumably one per epoch
        # if the trainer saves once per epoch -- confirm in Trainer).
        saver = tf.train.Saver(max_to_keep=config.n_epochs)

        train_eval = Seq2SeqEvaluator(model, config.grammar, train_data, 'train', config.reverse_dictionary, beam_size=config.beam_size, batch_size=config.batch_size)
        # NOTE(review): dev_data may be None here (no dev set on the command
        # line); confirm Seq2SeqEvaluator/Trainer handle that case.
        dev_eval = Seq2SeqEvaluator(model, config.grammar, dev_data, 'dev', config.reverse_dictionary, beam_size=config.beam_size, batch_size=config.batch_size)
        trainer = Trainer(model, train_data, train_eval, dev_eval, saver,
                          model_dir=model_dir,
                          max_length=config.max_length,
                          batch_size=config.batch_size,
                          n_epochs=config.n_epochs,
                          dropout=config.dropout)

        # Enable global XLA JIT compilation (level ON_1) for the session.
        tfconfig = tf.ConfigProto()
        tfconfig.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

        with tf.Session(config=tfconfig) as sess:
            # Run the Op to initialize the variables.
            sess.run(init)
            # Debugging aid: uncomment to step through the session in tfdbg.
            #sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            #sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

            # Fit the model
            best_dev, best_train = trainer.fit(sess)

            print("best train", best_train)
            print("best dev", best_dev)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号