def init_optimizer(self, state_dict=None):
    """Initialize an optimizer for the free parameters of the network.

    Args:
        state_dict: network parameters
    """
    # Optionally freeze the embedding layer so it is excluded from updates.
    if self.args.fix_embeddings:
        for p in self.network.embedding.parameters():
            p.requires_grad = False
    # Only parameters that still require gradients are handed to the optimizer.
    parameters = [p for p in self.network.parameters() if p.requires_grad]
    if self.args.optimizer == 'sgd':
        self.optimizer = optim.SGD(parameters, self.args.learning_rate,
                                   momentum=self.args.momentum,
                                   weight_decay=self.args.weight_decay)
    elif self.args.optimizer == 'adamax':
        self.optimizer = optim.Adamax(parameters,
                                      weight_decay=self.args.weight_decay)
    else:
        raise RuntimeError('Unsupported optimizer: %s' %
                           self.args.optimizer)
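# --------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original source): it mirrors the
# pattern above -- freeze the embedding layer, collect only the parameters that
# still require gradients, and hand them to torch.optim.  TinyNet and the
# hyperparameter values are assumptions made up for this example.
# --------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim


class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.embedding = nn.Embedding(100, 16)
        self.linear = nn.Linear(16, 2)

    def forward(self, x):
        # x: LongTensor of token ids, shape (batch, seq_len)
        return self.linear(self.embedding(x).mean(dim=1))


if __name__ == '__main__':
    net = TinyNet()
    # Equivalent of args.fix_embeddings: exclude embeddings from training.
    for p in net.embedding.parameters():
        p.requires_grad = False
    trainable = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable, lr=0.1, momentum=0.9, weight_decay=0)
    # One illustrative update step: the frozen embeddings stay unchanged.
    loss = net(torch.randint(0, 100, (4, 5))).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()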
# --------------------------------------------------------------------------
# Learning
# --------------------------------------------------------------------------