import sys

import chainer.optimizers as chaOpt  # chaOpt is chainer.optimizers, as in the rest of the script


def setOptimizer(args, EncDecAtt):
    # build the optimizer selected on the command line
    if args.optimizer == 'SGD':
        optimizer = chaOpt.SGD(lr=args.lrate)
        sys.stdout.write(
            '# SET Learning %s: initial learning rate: %e\n' %
            (args.optimizer, optimizer.lr))
    elif args.optimizer == 'Adam':
        # assert 0, "Currently Adam is not supported for asynchronous update"
        optimizer = chaOpt.Adam(alpha=args.lrate)
        sys.stdout.write(
            '# SET Learning %s: initial learning rate: %e\n' %
            (args.optimizer, optimizer.alpha))
    elif args.optimizer == 'MomentumSGD':
        optimizer = chaOpt.MomentumSGD(lr=args.lrate)
        sys.stdout.write(
            '# SET Learning %s: initial learning rate: %e\n' %
            (args.optimizer, optimizer.lr))
    elif args.optimizer == 'AdaDelta':
        # note: for AdaDelta, args.lrate is reused as the decay rate rho
        optimizer = chaOpt.AdaDelta(rho=args.lrate)
        sys.stdout.write(
            '# SET Learning %s: initial learning rate: %e\n' %
            (args.optimizer, optimizer.rho))
    else:
        assert 0, "ERROR: unknown optimizer: %s" % args.optimizer
    optimizer.setup(EncDecAtt.model)  # attach the model's parameters to the optimizer
    if args.optimizer == 'Adam':
        optimizer.t = 1  # hack to avoid a warning when Adam's lr is read before the first update
    return optimizer
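

# --- Usage sketch (not part of the original listing) -----------------------
# A minimal, hedged example of how setOptimizer would be called.  The real
# EncDecAtt class and the argparse options come from the surrounding training
# script; the tiny stand-in object and Namespace below are assumptions made
# only for illustration.
import argparse

import chainer.links as chaLink


class _DummyEncDecAtt(object):
    def __init__(self):
        # any chainer.Link is enough for optimizer.setup()
        self.model = chaLink.Linear(4, 4)


if __name__ == '__main__':
    demo_args = argparse.Namespace(optimizer='Adam', lrate=0.001)
    demo_opt = setOptimizer(demo_args, _DummyEncDecAtt())
    sys.stdout.write('# optimizer ready: %s\n' % type(demo_opt).__name__)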