data_loader.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pointer-network-tensorflow 作者: devsisters 项目源码 文件源码
def _maybe_generate_and_save(self, except_list=[]):
    self.data = {}

    for name, num in self.data_num.items():
      if name in except_list:
        tf.logging.info("Skip creating {} because of given except_list {}".format(name, except_list))
        continue
      path = self.get_path(name)

      if not os.path.exists(path):
        tf.logging.info("Creating {} for [{}]".format(path, self.task))

        x = np.zeros([num, self.max_length, 2], dtype=np.float32)
        y = np.zeros([num, self.max_length], dtype=np.int32)

        for idx in trange(num, desc="Create {} data".format(name)):
          n_nodes = self.rng.randint(self.min_length, self.max_length+ 1)
          nodes, res = generate_one_example(n_nodes, self.rng)
          x[idx,:len(nodes)] = nodes
          y[idx,:len(res)] = res

        np.savez(path, x=x, y=y)
        self.data[name] = TSP(x=x, y=y, name=name)
      else:
        tf.logging.info("Skip creating {} for [{}]".format(path, self.task))
        tmp = np.load(path)
        self.data[name] = TSP(x=tmp['x'], y=tmp['y'], name=name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号