adda_network.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:adda_mnist64 作者: davidtellez 项目源码 文件源码
def __init__(self):
        """Build the symbolic graph and compile the Theano functions.

        A classifier network is built on the image input, and a domain
        discriminator is attached to the classifier's 'classifier/pool1'
        layer. Five compiled functions are stored on the instance:

        * ``train_fn_classifier``: Adam step on the classifier parameters
          for (classification loss - discriminator loss).
        * ``train_fn_classifier_only``: Adam step on the classifier
          parameters for the classification loss alone.
        * ``train_fn_discriminator``: Adam step on the discriminator
          parameters for the discriminator loss.
        * ``valid_fn_classifier`` / ``valid_fn_discriminator``: return mean
          accuracy and the predictions, without any parameter update.
        """
        # ---- Symbolic inputs ------------------------------------------
        images = T.ftensor4('input_var')  # input images (batchx3x64x64)
        class_labels = T.ivector('labels_classifier_var')  # labels for images
        # Domain labels: 1 for source, 0 for target.
        domain_labels = T.ivector('labels_domain_var')
        step_size = T.fscalar('learning_rate')

        # ---- Networks -------------------------------------------------
        classifier_layers = self.network_classifier(images)
        # The discriminator takes the classifier's pool1 layer as its input.
        discriminator_layers = self.network_discriminator(
            classifier_layers['classifier/pool1'])
        classifier_head = classifier_layers['classifier/output']
        discriminator_head = discriminator_layers['discriminator/output']

        # ---- Predictions ----------------------------------------------
        # NOTE(review): the same (non-deterministic) outputs feed both the
        # training and validation functions; if the networks contain
        # stochastic layers, validation may want deterministic=True -- confirm.
        class_probs = get_output(classifier_head)  # prob image classification
        # Prob image domain (should be 1 for source).
        domain_probs = get_output(discriminator_head)

        # ---- Losses ---------------------------------------------------
        # Multiplying by the domain label zeroes the classification loss of
        # samples whose domain label is 0.
        per_sample_class_loss = categorical_crossentropy(class_probs, class_labels)
        class_loss = T.mean(per_sample_class_loss * domain_labels)
        domain_loss = T.mean(categorical_crossentropy(domain_probs, domain_labels))
        # The classifier is rewarded when the discriminator loss grows.
        adversarial_class_loss = class_loss - domain_loss

        # ---- Accuracies -----------------------------------------------
        class_accuracy = categorical_accuracy(class_probs, class_labels).mean()
        domain_accuracy = categorical_accuracy(domain_probs, domain_labels).mean()

        # ---- Trainable parameters -------------------------------------
        classifier_params = lasagne.layers.get_all_params(
            classifier_head, trainable=True)
        # get_all_params is called on the discriminator head, so keep only
        # the parameters whose name mentions 'discriminator'.
        discriminator_params = [
            p
            for p in lasagne.layers.get_all_params(discriminator_head, trainable=True)
            if 'discriminator' in p.name
        ]

        # ---- Adam update rules ----------------------------------------
        adversarial_updates = lasagne.updates.adam(
            adversarial_class_loss, classifier_params, learning_rate=step_size)
        class_only_updates = lasagne.updates.adam(
            class_loss, classifier_params, learning_rate=step_size)
        domain_updates = lasagne.updates.adam(
            domain_loss, discriminator_params, learning_rate=step_size)

        # ---- Training functions ---------------------------------------
        # Both classifier training functions share the same inputs/outputs
        # and differ only in which update rule they apply.
        classifier_train_inputs = [images, class_labels, domain_labels, step_size]
        classifier_train_outputs = [adversarial_class_loss, class_loss, class_probs]

        self.train_fn_classifier = theano.function(
            classifier_train_inputs,
            classifier_train_outputs,
            updates=adversarial_updates)
        self.train_fn_classifier_only = theano.function(
            classifier_train_inputs,
            classifier_train_outputs,
            updates=class_only_updates)
        self.train_fn_discriminator = theano.function(
            [images, domain_labels, step_size],
            [domain_loss, domain_probs],
            updates=domain_updates)

        # ---- Validation functions (no updates) ------------------------
        self.valid_fn_classifier = theano.function(
            [images, class_labels],
            [class_accuracy, class_probs])
        self.valid_fn_discriminator = theano.function(
            [images, domain_labels],
            [domain_accuracy, domain_probs])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号