x2y_yz2x_xy2p_ssl_cifar10.py source code

python

Project: triple-gan  Author: zhenxuan00
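from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T
import lasagne


def robust_adam(loss, params, learning_rate, beta1=0.9, beta2=0.999, epsilon=1.0e-8):
    """Adam updates in which any NaN in the moments or parameters is replaced by zero."""
    # Convert NaNs to zeros (elementwise).
    def clear_nan(x):
        return T.switch(T.isnan(x), np.float32(0.0), x)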

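    # Ordered mapping from each shared variable to its update expression.
    new = OrderedDict()
    pg = zip(params, lasagne.updates.get_or_compute_grads(loss, params))
    # Step counter, shared by all parameters.
    t = theano.shared(lasagne.utils.floatX(0.))

    new[t] = t + 1.0
    # Learning rate with the bias corrections for both moments folded in.
    coef = learning_rate * T.sqrt(1.0 - beta2**new[t]) / (1.0 - beta1**new[t])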
    for p, g in pg:
        value = p.get_value(borrow=True)
        m = theano.shared(np.zeros(value.shape, dtype=value.dtype), broadcastable=p.broadcastable)
        v = theano.shared(np.zeros(value.shape, dtype=value.dtype), broadcastable=p.broadcastable)
        new[m] = clear_nan(beta1 * m + (1.0 - beta1) * g)
        new[v] = clear_nan(beta2 * v + (1.0 - beta2) * g**2)
        new[p] = clear_nan(p - coef * new[m] / (T.sqrt(new[v]) + epsilon))

    return new
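
The update rule above can be checked by hand without Theano. The following is a minimal pure-Python sketch of a single scalar step, assuming the same formulas as `robust_adam`; the name `robust_adam_step`, its argument order, and the default `learning_rate=0.001` are hypothetical and not part of the original file.

```python
import math


def clear_nan(x):
    # Replace NaN with 0.0, mirroring T.switch(T.isnan(x), 0.0, x).
    return 0.0 if math.isnan(x) else x


def robust_adam_step(p, g, m, v, t, learning_rate=0.001,
                     beta1=0.9, beta2=0.999, epsilon=1.0e-8):
    """One scalar update (hypothetical helper); returns the new (p, m, v, t)."""
    t = t + 1.0
    # Bias corrections for both moments folded into the step size.
    coef = learning_rate * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)
    m = clear_nan(beta1 * m + (1.0 - beta1) * g)
    v = clear_nan(beta2 * v + (1.0 - beta2) * g ** 2)
    p = clear_nan(p - coef * m / (math.sqrt(v) + epsilon))
    return p, m, v, t


# A finite gradient moves the parameter against the gradient sign.
p, m, v, t = robust_adam_step(1.0, 0.5, 0.0, 0.0, 0.0)

# A NaN gradient resets both moment estimates to zero and leaves p unchanged.
p2, m2, v2, t2 = robust_adam_step(p, float("nan"), m, v, t)
```

In this sketch a NaN gradient discards the accumulated moments for that entry rather than skipping the step, so the next finite gradient starts from zeroed moments while the step counter keeps advancing.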