def test_reading_without_targets(self):
    """Read from a provider built with a source file and no target.

    Verifies that the provider lists exactly the ``source_tokens`` and
    ``source_len`` items, then fetches ``num_epochs * 3`` items and checks
    that each one reports a length of 2 and ends with ``SEQUENCE_END``.
    """
    num_epochs = 50
    provider = make_parallel_data_provider(
        data_sources_source=[self.source_file.name],
        data_sources_target=None,
        num_epochs=num_epochs,
        shuffle=True)

    # Pair every item name with the value the provider returns for it so a
    # single session run yields a dict keyed by item name.
    keys = list(provider.list_items())
    tensors = provider.get(keys)
    fetches = {key: tensor for key, tensor in zip(keys, tensors)}

    # With no target source, only the source-side items should be listed.
    self.assertEqual(set(keys), {"source_tokens", "source_len"})

    fetched = []
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            for _ in range(num_epochs * 3):
                fetched.append(sess.run(fetches))

    for item in fetched:
        self.assertEqual(item["source_len"], 2)
        # Cast to a byte-string array and decode as UTF-8 so the last token
        # can be compared against a str literal.
        tokens = np.char.decode(item["source_tokens"].astype("S"), "utf-8")
        self.assertEqual(tokens[-1], "SEQUENCE_END")
评论列表
文章目录