resnet.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:luna16 作者: gzuidhof 项目源码 文件源码
def define_updates(output_layer, X, Y):
    """Compile the Theano training and validation functions for a network.

    Args:
        output_layer: Final Lasagne layer of the network; every symbolic
            expression built here is derived from its output.
        X: Symbolic input variable of the compiled functions. It is not
            handed to ``get_output`` here, so it is presumably the
            ``input_var`` of the network's input layer -- confirm at the
            call site.
        Y: Symbolic target variable. With categorical cross-entropy as used
            here it should hold integer class labels, not one-hot vectors.

    Returns:
        A ``(train_fn, valid_fn, l_r)`` tuple. Both functions take
        ``(X, Y)`` and return ``[loss, l2_penalty, accuracy,
        argmax predictions, output[:, 1]]``. ``train_fn`` uses the
        non-deterministic network output and applies the Nesterov-momentum
        updates; ``valid_fn`` uses the deterministic output and applies no
        updates. ``l_r`` is the shared learning-rate variable (initialised
        from ``LR_SCHEDULE[0]``), returned so it can be changed later.
    """
    float_x = theano.config.floatX

    def mean_cross_entropy(probabilities):
        # Keep the probabilities inside [1e-6, 0.999999] before taking the
        # cross-entropy (presumably to avoid log(0)).
        bounded = T.clip(probabilities, 1e-6, 0.999999)
        per_sample = lasagne.objectives.categorical_crossentropy(bounded, Y)
        return per_sample.mean()

    def accuracy(probabilities):
        # Fraction of samples whose arg-max class equals the label.
        hits = T.eq(T.argmax(probabilities, axis=1), Y)
        return T.mean(hits, dtype=float_x)

    # Two views of the same network: the default (training) pass and the
    # deterministic pass used for validation.
    train_out = lasagne.layers.get_output(output_layer)
    valid_out = lasagne.layers.get_output(output_layer, deterministic=True)

    # L2 penalty over all layers of the network, scaled by P.L2_LAMBDA.
    network_layers = lasagne.layers.get_all_layers(output_layer)
    l2_penalty = lasagne.regularization.regularize_layer_params(
        network_layers, lasagne.regularization.l2
    ) * P.L2_LAMBDA

    # The same penalty term is added to both the training and validation loss.
    train_loss = mean_cross_entropy(train_out) + l2_penalty
    valid_loss = mean_cross_entropy(valid_out) + l2_penalty

    train_acc = accuracy(train_out)
    valid_acc = accuracy(valid_out)

    # The learning rate lives in a shared variable so it can be modified
    # after the functions have been compiled.
    l_r = theano.shared(np.array(LR_SCHEDULE[0], dtype=float_x))
    trainable_params = lasagne.layers.get_all_params(output_layer, trainable=True)
    # SGD with Nesterov momentum. Alternative kept for reference:
    #   adam(train_loss, trainable_params, learning_rate=l_r)
    updates = nesterov_momentum(
        train_loss, trainable_params, learning_rate=l_r, momentum=P.MOMENTUM
    )

    train_pred = T.argmax(train_out, axis=1)
    valid_pred = T.argmax(valid_out, axis=1)

    # Output order is identical for both functions; the last entry is
    # column 1 of the network output.
    train_fn = theano.function(
        inputs=[X, Y],
        outputs=[train_loss, l2_penalty, train_acc, train_pred, train_out[:, 1]],
        updates=updates,
    )
    valid_fn = theano.function(
        inputs=[X, Y],
        outputs=[valid_loss, l2_penalty, valid_acc, valid_pred, valid_out[:, 1]],
    )

    return train_fn, valid_fn, l_r
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号