LSTMEncDecAttn.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:mlpnlp-nmt 作者: mlpnlp 项目源码 文件源码
def setOptimizer(args, EncDecAtt):
    """Build the optimizer named by ``args.optimizer`` and attach it to the model.

    Supported names are 'SGD', 'MomentumSGD', 'Adam' and 'AdaDelta'. Each is
    looked up on ``chaOpt`` by that exact name and constructed with
    ``args.lrate`` as its single hyper-parameter; the value actually stored
    on the optimizer is echoed to stdout.

    Args:
        args: parsed command-line options; reads ``args.optimizer`` and
            ``args.lrate``.
        EncDecAtt: object whose ``model`` attribute is registered with the
            new optimizer via ``optimizer.setup``.

    Returns:
        The configured optimizer instance.

    Raises:
        ValueError: if ``args.optimizer`` is not one of the supported names.
    """
    # Optimizer name -> constructor keyword that receives args.lrate. The same
    # name is the attribute read back for the log line below.
    rate_keywords = {
        'SGD': 'lr',
        'Adam': 'alpha',
        'MomentumSGD': 'lr',
        # NOTE(review): AdaDelta takes no step size here; args.lrate is passed
        # as rho (AdaDelta's decay constant) yet is still logged as the
        # "initial learning rate". Kept unchanged to preserve behaviour.
        'AdaDelta': 'rho',
    }

    name = args.optimizer
    if name not in rate_keywords:
        # An explicit raise (rather than `assert`) survives `python -O`;
        # previously -O let execution continue with `optimizer` unbound.
        raise ValueError(
            'unknown optimizer %r; expected one of: %s' %
            (name, ', '.join(sorted(rate_keywords))))

    rate_keyword = rate_keywords[name]
    # The option strings match the class names exposed by chaOpt exactly.
    optimizer_class = getattr(chaOpt, name)
    optimizer = optimizer_class(**{rate_keyword: args.lrate})
    sys.stdout.write(
        '# SET Learning %s: initial learning rate: %e\n' %
        (name, getattr(optimizer, rate_keyword)))

    # Register the model's parameters with the optimizer.
    optimizer.setup(EncDecAtt.model)
    if name == 'Adam':
        # HACK carried over from the original, whose (garbled) comment says
        # it avoids a warning. Presumably Adam's bias-corrected lr divides by
        # (1 - beta1**t), which is zero while t == 0 -- confirm against the
        # Chainer version in use.
        optimizer.t = 1

    return optimizer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号