test_netwrok.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:a3c 作者: hercky 项目源码 文件源码
def rmsprop_updates(grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6):
    """Build RMSProp update expressions with a max-norm constraint.

    For every (param, grad) pair a shared accumulator holding a moving
    average of the squared gradient is created, and the parameter step is
    scaled by the inverse square root of that average.  The stepped
    parameter is then passed through ``lasagne.updates.norm_constraint``
    with a maximum norm of 40 computed over axis 0; if the constraint
    cannot be built for a parameter, the unconstrained step is used.

    Args:
        grads: Gradient expressions, one per entry of ``params`` (paired
            positionally via ``zip``).
        params: Shared variables to update; each must support
            ``get_value(borrow=True)`` and expose ``broadcastable``.
        learning_rate: Step size multiplier.
        rho: Decay factor of the squared-gradient moving average.
        epsilon: Small constant added inside the square root for
            numerical stability.

    Returns:
        OrderedDict mapping each accumulator and each parameter to its
        update expression, in the order accumulator-then-parameter per
        pair.
    """
    updates = OrderedDict()
    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)
    for param, grad in zip(params, grads):
        value = param.get_value(borrow=True)
        accu = theano.shared(
            numpy.zeros(value.shape, dtype=value.dtype),
            broadcastable=param.broadcastable,
        )
        accu_new = rho * accu + (one - rho) * grad ** 2
        updates[accu] = accu_new
        mid_up = param - (learning_rate * grad / (T.sqrt(accu_new + epsilon)))
        try:
            # norm_axes is documented as a sequence of axes; the previous
            # bare int ``0`` appears to have made the call raise, so the
            # old bare ``except`` silently skipped the constraint for
            # every parameter.
            updates[param] = lasagne.updates.norm_constraint(
                mid_up, 40, norm_axes=(0,)
            )
        except Exception:
            # Best effort: parameters the constraint cannot handle (for
            # example ones without an axis 0) fall back to the plain
            # RMSProp step.  KeyboardInterrupt/SystemExit now propagate.
            updates[param] = mid_up
    return updates
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号