def __init__(self, model, **kwargs):
    """Configure the optimizer, loss and device placement for ``model``.

    Recognized keyword arguments:

    * ``gpu``   -- coerced with ``bool``; defaults to ``True``. When truthy,
      the model is wrapped in ``torch.nn.DataParallel`` and ``.cuda()`` is
      called on both the wrapped model and the loss.
    * ``optim`` -- optimizer name: ``'adadelta'``, ``'adam'`` (the default)
      or ``'rmsprop'``; any other value selects SGD.
    * ``eta``   -- learning rate, converted to ``float``; defaults to 0.01.
    * ``mom``   -- momentum, converted to ``float``; defaults to 0.9 and is
      only passed to SGD.
    * ``clip``  -- converted to ``float`` and stored as ``self.clip``;
      defaults to 5.

    :param model: object exposing ``parameters()``, ``make_input`` and
        ``create_loss()``.
    """
    super(Seq2SeqTrainerPyTorch, self).__init__()
    self.steps = 0
    self.gpu = bool(kwargs.get('gpu', True))

    optimizer_name = kwargs.get('optim', 'adam')
    learning_rate = float(kwargs.get('eta', 0.01))
    momentum = float(kwargs.get('mom', 0.9))
    self.clip = float(kwargs.get('clip', 5))

    # Optimizers that are built from just the parameters and a learning rate.
    lr_only_optimizers = (
        ('adadelta', torch.optim.Adadelta),
        ('adam', torch.optim.Adam),
        ('rmsprop', torch.optim.RMSprop),
    )
    optimizer_cls = next(
        (cls for name, cls in lr_only_optimizers if optimizer_name == name),
        None,
    )
    if optimizer_cls is not None:
        self.optimizer = optimizer_cls(model.parameters(), lr=learning_rate)
    else:
        # Any unrecognized name falls through to SGD with momentum.
        self.optimizer = torch.optim.SGD(
            model.parameters(), lr=learning_rate, momentum=momentum
        )

    # Taken from the unwrapped model, before any DataParallel wrapping below.
    self._input = model.make_input
    self.crit = model.create_loss()

    if not self.gpu:
        self.model = model
        return

    self.model = torch.nn.DataParallel(model).cuda()
    self.crit.cuda()
评论列表
文章目录