train.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:comprehend 作者: Fenugreek 项目源码 文件源码
def get_trainer(cost, learning_rate=.001, grad_clips=(-1., 1.), logger=logger,
                **kwargs):
    """Return opertation that trains parameters, given cost tensor."""

    opt = tf.train.AdamOptimizer(learning_rate)
    if grad_clips is None: return opt.minimize(cost, **kwargs)

    grads_vars = []
    for grad_var in opt.compute_gradients(cost, **kwargs):
        if grad_var[0] is None:
            if logger is not None:
                logger.info('No gradient for variable {}', grad_var[1].name)
            continue
        grads_vars.append((tf.clip_by_value(grad_var[0], -1., 1.), grad_var[1]))

    return opt.apply_gradients(grads_vars)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号