def test(test_args):
    """Evaluate a trained language model on a held-out test file.

    Restores the training-time configuration from ``config.pkl`` in
    ``test_args.save_dir``, rebuilds the matching model graph, loads the
    latest checkpoint from that directory, and prints the test-set
    perplexity and elapsed wall-clock time.

    Args:
        test_args: namespace with at least ``save_dir`` (directory holding
            ``config.pkl`` and the model checkpoints) and ``test_file``
            (path to the evaluation corpus).
    """
    start = time.time()
    # Reload the exact args the model was trained with so the graph is
    # rebuilt identically; only save_dir is taken from the test invocation.
    # NOTE(review): pickle.load can execute arbitrary code on a crafted
    # file -- config.pkl is assumed to come from our own training run.
    with open(os.path.join(test_args.save_dir, 'config.pkl'), 'rb') as f:
        args = pickle.load(f)
    args.save_dir = test_args.save_dir
    data_loader = TextLoader(args, train=False)
    test_data = data_loader.read_dataset(test_args.test_file)
    print(args.save_dir)
    print("Unit: " + args.unit)
    print("Composition: " + args.composition)
    # Vocabulary sizes come from the (reloaded) data, not the pickled args.
    args.word_vocab_size = data_loader.word_vocab_size
    print("Word vocab size: " + str(data_loader.word_vocab_size))
    if args.unit != "word":
        args.subword_vocab_size = data_loader.subword_vocab_size
        print("Subword vocab size: " + str(data_loader.subword_vocab_size))
    # The bi-LSTM composition unrolls over the sub-units of each word, so
    # its number of time steps depends on which sub-unit is in use.
    if args.composition == "bi-lstm":
        if args.unit == "char":
            args.bilstm_num_steps = data_loader.max_word_len
            print("Max word length:", data_loader.max_word_len)
        elif args.unit == "char-ngram":
            args.bilstm_num_steps = data_loader.max_ngram_per_word
            print("Max ngrams per word:", data_loader.max_ngram_per_word)
        elif args.unit == "morpheme" or args.unit == "oracle":
            args.bilstm_num_steps = data_loader.max_morph_per_word
            print("Max morphemes per word", data_loader.max_morph_per_word)
    # Select the model class matching the training configuration; word-level
    # models ignore the composition setting entirely.
    if args.unit == "word":
        lm_model = WordModel
    elif args.composition == "addition":
        lm_model = AdditiveModel
    elif args.composition == "bi-lstm":
        lm_model = BiLSTMModel
    else:
        sys.exit("Unknown unit or composition.")
    print("Begin testing...")
    with tf.Graph().as_default(), tf.Session() as sess:
        with tf.variable_scope("model"):
            mtest = lm_model(args, is_training=False, is_testing=True)
        # tf.all_variables / tf.initialize_all_variables were deprecated in
        # TF 1.0 and later removed; use the renamed equivalents.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        tf.global_variables_initializer().run()
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Without a restored checkpoint we would be scoring randomly
            # initialized weights -- fail loudly instead of silently.
            sys.exit("No checkpoint found in " + args.save_dir)
        test_perplexity = run_epoch(sess, mtest, test_data, data_loader, tf.no_op())
        print("Test Perplexity: %.3f" % test_perplexity)
        print("Test time: %.0f\n" % (time.time() - start))
    print("\n")