import numpy as np
import theano
import theano.tensor as T
import lasagne
from lasagne.updates import nesterov_momentum

# P (the project's hyperparameter module) and LR_SCHEDULE (a mapping from
# epoch to learning rate) are assumed to be defined elsewhere, as before.

def define_updates(output_layer, X, Y):
    output_train = lasagne.layers.get_output(output_layer)
    output_test = lasagne.layers.get_output(output_layer, deterministic=True)

    # Loss to minimize: categorical cross-entropy expects Y to hold integer
    # class labels, not one-hot vectors. Clipping the network output keeps
    # log() away from zero and avoids NaN losses.
    loss = lasagne.objectives.categorical_crossentropy(T.clip(output_train, 1e-6, 1 - 1e-6), Y)
    loss = loss.mean()
    acc = T.mean(T.eq(T.argmax(output_train, axis=1), Y), dtype=theano.config.floatX)

    # L2 weight decay over all layers (useful e.g. when training a ResNet).
    all_layers = lasagne.layers.get_all_layers(output_layer)
    l2_penalty = lasagne.regularization.regularize_layer_params(all_layers, lasagne.regularization.l2) * P.L2_LAMBDA
    loss = loss + l2_penalty

    # Loss and accuracy for the validation set, computed with the
    # deterministic (dropout-free) forward pass.
    test_loss = lasagne.objectives.categorical_crossentropy(T.clip(output_test, 1e-6, 1 - 1e-6), Y)
    test_loss = test_loss.mean() + l2_penalty
    test_acc = T.mean(T.eq(T.argmax(output_test, axis=1), Y), dtype=theano.config.floatX)

    # Fetch the trainable parameters and set up SGD with Nesterov momentum.
    # The learning rate is a shared variable so it can be changed during
    # training according to LR_SCHEDULE.
    l_r = theano.shared(np.array(LR_SCHEDULE[0], dtype=theano.config.floatX))
    params = lasagne.layers.get_all_params(output_layer, trainable=True)
    updates = nesterov_momentum(loss, params, learning_rate=l_r, momentum=P.MOMENTUM)
    # updates = lasagne.updates.adam(loss, params, learning_rate=l_r)

    prediction_binary = T.argmax(output_train, axis=1)
    test_prediction_binary = T.argmax(output_test, axis=1)

    # Compile the training and validation functions; output[:, 1] is the
    # predicted probability of the positive class (binary classification).
    train_fn = theano.function(inputs=[X, Y], outputs=[loss, l2_penalty, acc, prediction_binary, output_train[:, 1]], updates=updates)
    valid_fn = theano.function(inputs=[X, Y], outputs=[test_loss, l2_penalty, test_acc, test_prediction_binary, output_test[:, 1]])
    return train_fn, valid_fn, l_r
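
For reference, here is a minimal sketch of how this function might be wired up, assuming everything lives in the same script as define_updates. The two-layer stand-in network, the P placeholder class, the LR_SCHEDULE values, and the dummy minibatch are illustrative assumptions, not part of the original project:

import numpy as np
import theano
import theano.tensor as T
import lasagne

class P:  # stand-in for the project's hyperparameter module
    MOMENTUM = 0.9
    L2_LAMBDA = 1e-4

LR_SCHEDULE = {0: 0.01, 10: 0.001}  # hypothetical epoch -> learning rate

X = T.tensor4('inputs')   # batch of images, shape (batch, channels, h, w)
Y = T.ivector('targets')  # integer class labels, not one-hot

# Tiny two-class network purely for illustration; the real project
# builds something much deeper (e.g. a ResNet).
net = lasagne.layers.InputLayer((None, 1, 28, 28), input_var=X)
net = lasagne.layers.DenseLayer(net, num_units=64)
net = lasagne.layers.DenseLayer(net, num_units=2, nonlinearity=lasagne.nonlinearities.softmax)

train_fn, valid_fn, l_r = define_updates(net, X, Y)

# One dummy minibatch, just to show the call signature:
x_batch = np.random.rand(8, 1, 28, 28).astype('float32')
y_batch = np.random.randint(0, 2, size=8).astype('int32')
loss, l2, acc, preds, pos_probs = train_fn(x_batch, y_batch)

# Since l_r is shared, the schedule can be applied without recompiling:
l_r.set_value(np.array(LR_SCHEDULE[10], dtype=theano.config.floatX))

Returning l_r rather than baking a constant learning rate into the graph is what lets the training loop apply LR_SCHEDULE with a simple set_value call between epochs.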