def getNextIndex(self):
    """Locate the file and in-file offset where the next scratch window ends.

    The window end is the global sample position
    ``(scratch_index + 1) * batch_memory * batch_size``, counted across all
    ``.data`` files.

    Assumes ``self.num_points`` is an ascending cumulative sample count with a
    leading 0, so ``num_points[j]`` is the first global sample of file ``j``
    (TODO confirm against the code that builds ``num_points``).

    Sets (returns None):
        self.nindex: index ``j`` of the file containing the window end, i.e.
            the largest ``j`` with ``num_points[j] < target``.
        self.idxend: offset of the window end inside that file,
            ``target - num_points[j]``.

    If the window end reaches or passes the last sample, the target is
    clamped to ``num_points[-1]``. The window then ends at the end of the
    last file.
    """
    num_points = self.num_points
    target_value = (self.scratch_index + 1) * (self.batch_memory * self.batch_size)
    idx_target = np.searchsorted(num_points, target_value, side='right')
    if target_value > num_points[-1] or idx_target >= len(num_points):
        # Clamp to the end of the data. The last file has index len-2 because
        # of the leading 0. The original code computed idxend from the
        # second-to-last file here, which was wrong (negative for one file).
        target_value = num_points[-1]
        idx_target = len(num_points) - 2
    else:
        # searchsorted(side='right') gives the first j with num_points[j] > target.
        # Step back to the largest j with num_points[j] < target. A target
        # that lands exactly on a boundary therefore belongs to the file that
        # ends there.
        while target_value <= num_points[idx_target]:
            idx_target = idx_target - 1
    self.idxend = target_value - num_points[idx_target]
    self.nindex = idx_target
评论列表
文章目录