data_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: google 项目源码 文件源码
def test_reading_without_targets(self):
    """Check that a source-only parallel data provider exposes source items only.

    Builds the provider with ``data_sources_target=None`` and asserts that the
    only listed items are ``source_tokens`` and ``source_len``. It then fetches
    ``num_epochs * 3`` examples in a test session and asserts that every example
    has ``source_len == 2`` and that its last decoded token is ``"SEQUENCE_END"``.
    """
    epochs = 50
    provider = make_parallel_data_provider(
        data_sources_source=[self.source_file.name],
        data_sources_target=None,
        num_epochs=epochs,
        shuffle=True)

    keys = list(provider.list_items())
    tensors = provider.get(keys)
    # Map each item name to its tensor so one sess.run returns a dict.
    fetches = dict(zip(keys, tensors))

    self.assertEqual(set(keys), {"source_tokens", "source_len"})

    num_reads = epochs * 3
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      # NOTE(review): QueueRunners presumably starts the input queue threads
      # for the duration of this context; confirm against tf.contrib.slim.
      with tf.contrib.slim.queues.QueueRunners(sess):
        examples = [sess.run(fetches) for _ in range(num_reads)]

    for example in examples:
      self.assertEqual(example["source_len"], 2)
      # Cast to a fixed-width bytes array so np.char.decode can decode
      # each element from UTF-8 into a str.
      decoded = np.char.decode(example["source_tokens"].astype("S"), "utf-8")
      self.assertEqual(decoded[-1], "SEQUENCE_END")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号