def get_update(Ws_s, bs_s):
    """Compile a Theano function that performs one Nesterov-momentum training step.

    Builds the model graph with ``train.get_model`` and fits its output ``fx``
    as the logit of a sigmoid against the targets ``y`` (who won), minimising
    the mean negative log likelihood (binary cross-entropy).

    Args:
        Ws_s: Weight parameters; concatenated with ``bs_s`` (``Ws_s + bs_s``)
            to form the parameter collection passed to
            ``train.nesterov_updates``. Presumably a list of shared
            variables -- TODO confirm against ``train``.
        bs_s: Bias parameters, combined with ``Ws_s`` as above.

    Returns:
        A compiled ``theano.function`` with inputs
        ``(x, y, learning_rate, momentum)``. Each call applies the parameter
        updates and returns ``[loss, frac_correct]``.
    """
    x, fx = train.get_model(Ws_s, bs_s)

    # Ground truth (who won); presumably 0/1 values given the loss below.
    y = T.vector('y')

    # Log likelihood of a sigmoid fit, written directly in terms of the logit:
    #   -log(sigmoid(fx))     == softplus(-fx)
    #   -log(1 - sigmoid(fx)) == softplus(fx)
    # The softplus form never evaluates log(0). A saturated sigmoid therefore
    # cannot turn the loss or its gradients into inf/NaN. The naive
    # T.log(sigmoid(...)) form only stays finite if Theano's stabilization
    # rewrite happens to apply.
    loss = (y * T.nnet.softplus(-fx) + (1 - y) * T.nnet.softplus(fx)).mean()

    # Fraction of correctly predicted outcomes: a positive logit predicts
    # y = 1 and a negative logit predicts y = 0. A logit of exactly 0 matches
    # neither term, so it counts as incorrect.
    frac_correct = ((fx > 0) * y + (fx < 0) * (1 - y)).mean()

    # Hyperparameters are symbolic inputs so they can change between calls.
    learning_rate_s = T.scalar(dtype=theano.config.floatX)
    momentum_s = T.scalar(dtype=theano.config.floatX)
    updates = train.nesterov_updates(loss, Ws_s + bs_s, learning_rate_s, momentum_s)

    f_update = theano.function(
        inputs=[x, y, learning_rate_s, momentum_s],
        outputs=[loss, frac_correct],
        updates=updates,
    )
    return f_update
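# NOTE(review): stray page-navigation text, likely scraping residue:
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")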