def __init__(self, expt_dir='experiment', loss=None, batch_size=64,
             random_seed=None,
             checkpoint_every=100, print_every=100):
    """Configure the trainer: seeding, loss, evaluator and experiment dir.

    Args:
        expt_dir (str): directory for experiment output. A relative path is
            resolved against the current working directory, and the
            directory is created if it does not already exist.
        loss: loss object used by this trainer and handed to its Evaluator.
            When ``None`` (the default), a fresh ``NLLLoss()`` is created
            for this instance.
        batch_size (int): batch size, stored on the trainer and also passed
            to the Evaluator.
        random_seed (int, optional): if not ``None``, seeds both Python's
            ``random`` module and ``torch``.
        checkpoint_every (int): stored as ``self.checkpoint_every``;
            presumably the step interval between checkpoints (confirm in the
            training loop).
        print_every (int): stored as ``self.print_every``; presumably the
            step interval between progress logs (confirm in the training
            loop).

    Raises:
        OSError: if ``expt_dir`` cannot be created, or the path exists but
            is not a directory.
    """
    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)

    # Build the default loss per call. A ``loss=NLLLoss()`` default is
    # evaluated once at definition time, so every trainer (and its
    # Evaluator) would share the same loss object and any state it keeps.
    if loss is None:
        loss = NLLLoss()
    self.loss = loss
    self.evaluator = Evaluator(loss=self.loss, batch_size=batch_size)
    # NOTE(review): presumably assigned later (e.g. by the train routine).
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every

    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    # Create the directory without an exists()/makedirs() race. If another
    # process creates it first, that is fine. Re-raise if the path is still
    # missing or is not a directory.
    try:
        os.makedirs(self.expt_dir)
    except OSError:
        if not os.path.isdir(self.expt_dir):
            raise

    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
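# 评论列表 ("comment list") / 文章目录 ("table of contents"): stray web-page
# navigation text captured with this snippet; not part of the program.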