def find_previous(self):
    """Find previously saved snapshot files in ``self.output_dir``.

    Two families of files are collected, each sorted by modification time
    (oldest first):

    * ``<SNAPSHOT_PREFIX>_iter_*.ckpt.meta`` -- returned with the trailing
      ``.meta`` removed (the snapshot name in TensorFlow).
    * ``<SNAPSHOT_PREFIX>_iter_*.pkl`` -- returned unchanged.

    Files whose iteration number equals ``stepsize + 1`` for any entry of
    ``cfg.TRAIN.STEPSIZE`` are treated as redundant and left out of both
    lists.

    Returns:
        A tuple ``(lsf, nfiles, sfiles)``: the number of snapshots found,
        the ``.pkl`` paths, and the checkpoint paths without ``.meta``.

    Raises:
        AssertionError: if the number of ``.pkl`` files differs from the
            number of checkpoint files after filtering.
    """
    meta_suffix = '.ckpt.meta'
    pkl_suffix = '.pkl'
    # Common leading part of every snapshot path: <dir>/<prefix>_iter_
    iter_prefix = os.path.join(self.output_dir,
                               cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_')

    # Extension-less paths of the redundant snapshots.  Comparing on the
    # stem lets one set filter both the .ckpt.meta and the .pkl lists.
    redundant = set(iter_prefix + '{:d}'.format(stepsize + 1)
                    for stepsize in cfg.TRAIN.STEPSIZE)

    meta_files = glob.glob(iter_prefix + '*' + meta_suffix)
    meta_files.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow.  Suffixes are cut off by length
    # rather than with str.replace, which would also rewrite a '.meta'
    # occurring earlier in the path (e.g. inside a directory name).
    sfiles = [path[:-len('.meta')] for path in meta_files
              if path[:-len(meta_suffix)] not in redundant]

    nfiles = glob.glob(iter_prefix + '*' + pkl_suffix)
    nfiles.sort(key=os.path.getmtime)
    nfiles = [path for path in nfiles
              if path[:-len(pkl_suffix)] not in redundant]

    lsf = len(sfiles)
    assert len(nfiles) == lsf, (
        'found {:d} .pkl files but {:d} checkpoint files in {}'.format(
            len(nfiles), lsf, self.output_dir))
    return lsf, nfiles, sfiles
评论列表
文章目录