reinforcement.py source code

python

Project: deep-murasaki  Author: lazydroid
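# Imports inferred from the names used below (an assumption; the snippet
# omits them). `train` is the project's own module.
import theano
import theano.tensor as T
from theano.tensor.nnet import sigmoid

import train


def get_update(Ws_s, bs_s):
    # Build a compiled Theano function that applies one update step and
    # returns the loss and the fraction of correct predictions.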
    x, fx = train.get_model(Ws_s, bs_s)

    # Ground truth (who won)
    y = T.vector('y')

    # Compute loss (negative log likelihood of a sigmoid fit, i.e. binary cross-entropy)
    y_pred = sigmoid(fx)
    loss = -( y * T.log(y_pred) + (1 - y) * T.log(1 - y_pred)).mean()

    # Fraction of examples where the sign of fx agrees with the label
    # (fx == 0 counts as incorrect for either label)
    frac_correct = ((fx > 0) * y + (fx < 0) * (1 - y)).mean()

    # Updates: learning rate and momentum are symbolic inputs, so they can change per call
    learning_rate_s = T.scalar(dtype=theano.config.floatX)
    momentum_s = T.scalar(dtype=theano.config.floatX)
    updates = train.nesterov_updates(loss, Ws_s + bs_s, learning_rate_s, momentum_s)

    f_update = theano.function(
        inputs=[x, y, learning_rate_s, momentum_s],
        outputs=[loss, frac_correct],
        updates=updates,
        )

    return f_update
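
The loss and the accuracy metric built symbolically above can be checked numerically without Theano. The following is a minimal NumPy sketch, assuming fx is a vector of raw model scores and y a vector of 0/1 labels. The helper names bce_loss and frac_correct and the sample values are hypothetical and not part of the project.

```python
import numpy as np


def sigmoid(z):
    # Logistic function, same role as theano.tensor.nnet.sigmoid above
    return 1.0 / (1.0 + np.exp(-z))


def bce_loss(fx, y):
    # Mean negative log likelihood of a sigmoid fit (hypothetical helper)
    y_pred = sigmoid(fx)
    return float(-(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred)).mean())


def frac_correct(fx, y):
    # Positive score predicts label 1, negative score predicts label 0;
    # a score of exactly 0 is counted as wrong for either label
    return float(((fx > 0) * y + (fx < 0) * (1 - y)).mean())


# Made-up scores and labels, for illustration only
fx = np.array([2.0, -1.0, 0.5, -3.0])
y = np.array([1.0, 0.0, 0.0, 1.0])

print("loss:", bce_loss(fx, y))
print("fraction correct:", frac_correct(fx, y))
```

Because the loss takes log(y_pred) directly, very large scores can make it overflow to inf in floating point. The sketch keeps that form only to mirror the snippet.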