train_cnn.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:CNN_denoise 作者: weedwind 项目源码 文件源码
def get_loader(chunk_list):
   data = []
   label = []
   for f in chunk_list:
       print ('Loading data from %s' %f)
       with h5py.File(f, 'r') as hf:
          data.append(np.asarray(hf['data']))
          label.append(np.asarray(hf['label']))
   data = torch.FloatTensor(np.concatenate(data, axis = 0))
   label = torch.FloatTensor(np.concatenate(label, axis = 0))
   print ('Total %d frames loaded' %data.size(0))

   dset_train = TensorDataset(data, label)
   loader_train = DataLoader(dset_train, batch_size = 256, shuffle = True, num_workers = 10, pin_memory = False)
   return loader_train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号