def __init__(self, data_root, labels_file, num_samples=100, num_classes=1999):
    """Stack sigmoid-activated scores from several saved tensors column-wise.

    Every file matching ``data_root + "*val.pth"`` is loaded with ``t.load``,
    cast to float, truncated to its first ``num_samples`` rows, passed
    through a sigmoid, and written into its own ``num_classes``-wide column
    block of ``self.data``. Presumably each file holds one model's raw
    validation-set outputs (TODO confirm against the producer of the files).

    Args:
        data_root: Prefix prepended verbatim to the pattern ``"*val.pth"``;
            when it names a directory it must end with a path separator.
        labels_file: Stored unchanged as ``self.label_file_path``; not read
            here.
        num_samples: Number of leading rows kept from each file (default
            100, the value previously hard-coded).
        num_classes: Width of each file's column block (default 1999, the
            value previously hard-coded; presumably the number of label
            classes — confirm against the data).

    Side effects:
        Sets ``data_files_path``, ``model_num``, ``label_file_path`` and
        ``data`` on ``self``, then prints the size of ``self.data``.
    """
    # glob() yields matches in arbitrary filesystem order; sort so each
    # model's column block lands at the same position on every run/machine.
    self.data_files_path = sorted(glob(data_root + "*val.pth"))
    self.model_num = len(self.data_files_path)
    self.label_file_path = labels_file
    self.data = t.zeros(num_samples, num_classes * self.model_num)
    for i, path in enumerate(self.data_files_path):
        # NOTE(review): t.load may unpickle arbitrary objects depending on
        # the torch version / weights_only setting; only load trusted files.
        scores = t.load(path).float()[:num_samples]
        start = i * num_classes
        # A tensor that cannot fill this (num_samples, num_classes) slot
        # makes the assignment raise.
        self.data[:, start:start + num_classes] = t.sigmoid(scores)
    print(self.data.size())
评论列表
文章目录