tsp_task.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:neural-combinatorial-rl-pytorch 作者: pemami4911 项目源码 文件源码
def __init__(self, dataset_fname=None, train=False, size=50, num_samples=1000000, random_seed=1111):
        super(TSPDataset, self).__init__()
        #start = torch.FloatTensor([[-1], [-1]]) 

        torch.manual_seed(random_seed)

        self.data_set = []
        if not train:
            with open(dataset_fname, 'r') as dset:
                for l in tqdm(dset):
                    inputs, outputs = l.split(' output ')
                    sample = torch.zeros(1, )
                    x = np.array(inputs.split(), dtype=np.float32).reshape([-1, 2]).T
                    #y.append(np.array(outputs.split(), dtype=np.int32)[:-1]) # skip the last one
                    self.data_set.append(x)
        else:
            # randomly sample points uniformly from [0, 1]
            for l in tqdm(range(num_samples)):
                x = torch.FloatTensor(2, size).uniform_(0, 1)
                #x = torch.cat([start, x], 1)
                self.data_set.append(x)

        self.size = len(self.data_set)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号