def load_last_checkpoint(checkpoints_path, map_location=None):
    """Load the most recent "last" checkpoint found in a directory.

    Candidate files are located by filling both placeholders of
    ``SaverPlugin.last_pattern`` with glob wildcards. Each candidate's
    basename is then matched against the same pattern with numeric capture
    groups. Candidates whose placeholder fields are not numeric are skipped
    rather than causing a crash.

    Args:
        checkpoints_path: Directory to search. It is glob-escaped, so
            characters such as ``[`` or ``*`` in the path are treated
            literally.
        map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine). ``None``, the
            default, keeps ``torch.load``'s default behaviour.

    Returns:
        ``(checkpoint, epoch, iteration)`` for the candidate with the largest
        ``(epoch, iteration)`` pair, where ``checkpoint`` is whatever
        ``torch.load`` returns. Returns ``None`` if no usable checkpoint
        file exists.
    """
    from glob import escape as glob_escape

    checkpoints_pattern = os.path.join(
        glob_escape(checkpoints_path), SaverPlugin.last_pattern.format('*', '*')
    )
    # Same template, but with numeric capture groups for epoch and iteration.
    name_regex = re.compile(
        SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)')
    )

    latest = None  # ((epoch, iteration), path) of the best candidate so far
    # Sort for a deterministic winner when two files share the same key.
    for checkpoint_path in sorted(glob(checkpoints_pattern)):
        match = name_regex.match(os.path.basename(checkpoint_path))
        if match is None:
            # The glob wildcards accept any characters; ignore names whose
            # epoch/iteration fields are not numeric instead of crashing.
            continue
        key = (int(match.group(1)), int(match.group(2)))
        if latest is None or key >= latest[0]:
            latest = (key, checkpoint_path)

    if latest is None:
        return None

    (epoch, iteration), checkpoint_path = latest
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    return (checkpoint, epoch, iteration)
评论列表
文章目录