def test_with_counts(self):
    """Check the three lookup tables built from a vocab file that has counts.

    A three-word vocabulary with counts 100/200/300 is written to a temporary
    file and passed to ``vocabulary.create_vocabulary_lookup_table``. The
    reported size is expected to be 6, i.e. three entries beyond the file's
    own words. Out-of-vocabulary words are expected to share id 3, which maps
    back to "UNK", and to report a count of -1.
    """
    words_in_file = ["Hello", ".", "?"]
    counts_in_file = [100, 200, 300]
    tmp_vocab = create_temporary_vocab_file(words_in_file, counts_in_file)

    (word_to_id, id_to_word, word_to_count,
     reported_size) = vocabulary.create_vocabulary_lookup_table(tmp_vocab.name)
    self.assertEqual(reported_size, 6)

    # The three known words followed by two out-of-vocabulary probes.
    probe_words = ["Hello", ".", "?", "??", "xxx"]

    with self.test_session() as sess:
        # Initialize variables and lookup tables before running any lookup.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # word -> id: unknown words collapse onto the same id.
        id_op = word_to_id.lookup(tf.convert_to_tensor(probe_words))
        self.assertAllEqual(sess.run(id_op), [0, 1, 2, 3, 3])

        # id -> word: results are decoded from bytes to str before comparing.
        id_tensor = tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64)
        word_bytes = sess.run(id_to_word.lookup(id_tensor))
        decoded = np.char.decode(word_bytes.astype("S"), "utf-8")
        self.assertAllEqual(decoded, ["Hello", ".", "?", "UNK"])

        # word -> count: unknown words report -1.
        count_op = word_to_count.lookup(tf.convert_to_tensor(probe_words))
        self.assertAllEqual(sess.run(count_op), [100, 200, 300, -1, -1])
# Residue from the web page this code was copied from: "评论列表" (comment list), "文章目录" (table of contents).