feedforward.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:deep-iv 作者: allentran 项目源码 文件源码
def build_treatment_model(self, n_vars, **kwargs):
        """Build the treatment network and compile its Theano functions.

        The network is: input -> dropout -> dense (2 * dense_size, ReLU) ->
        batch norm -> dropout -> (n_dense_layers - 1) x [dense (dense_size,
        ReLU) -> batch norm] -> dense (1 unit, linear).

        Side effects: sets ``self.treatment_output`` (the output layer),
        ``self._train_fn`` (one adadelta step, returns the training loss),
        ``self._loss_fn`` (returns the loss without updating parameters) and
        ``self._output_fn`` (returns deterministic predictions).

        Args:
            n_vars: Number of columns of the input matrix.
            **kwargs: Must contain ``dense_size`` (width of the hidden layers)
                and ``n_dense_layers`` (total number of hidden dense layers).

        Returns:
            The parameter values of the freshly initialised network, as
            returned by ``layers.get_all_param_values``.

        Raises:
            KeyError: If ``dense_size`` or ``n_dense_layers`` is missing.
        """
        dense_size = kwargs['dense_size']
        n_dense_layers = kwargs['n_dense_layers']

        # Symbolic inputs of the compiled functions.
        input_vars = TT.matrix()
        instrument_vars = TT.matrix()
        targets = TT.vector()

        inputs = layers.InputLayer((None, n_vars), input_vars)
        inputs = layers.DropoutLayer(inputs, p=0.2)

        # The first hidden layer is twice as wide as the following ones.
        dense_layer = layers.DenseLayer(inputs, 2 * dense_size, nonlinearity=nonlinearities.rectify)
        dense_layer = layers.batch_norm(dense_layer)
        dense_layer = layers.DropoutLayer(dense_layer, p=0.2)

        # `range` instead of `xrange` so the method also runs on Python 3.
        for _ in range(n_dense_layers - 1):
            dense_layer = layers.DenseLayer(dense_layer, dense_size, nonlinearity=nonlinearities.rectify)
            dense_layer = layers.batch_norm(dense_layer)

        self.treatment_output = layers.DenseLayer(dense_layer, 1, nonlinearity=nonlinearities.linear)
        # Snapshot the initial weights before any training step is taken.
        init_params = layers.get_all_param_values(self.treatment_output)

        # Training graph (dropout / batch norm in training mode) and
        # deterministic graph used for evaluation and prediction.
        prediction = layers.get_output(self.treatment_output, deterministic=False)
        test_prediction = layers.get_output(self.treatment_output, deterministic=True)

        l2_cost = regularization.regularize_network_params(self.treatment_output, regularization.l2)
        l2_penalty = 1e-4 * l2_cost
        # gmm_loss is defined elsewhere in the project.
        loss = gmm_loss(prediction, targets, instrument_vars) + l2_penalty
        # Evaluation loss is built from the deterministic output: with the
        # training graph every call would sample new dropout masks and run
        # batch norm in training mode, making the reported loss noisy and
        # letting evaluation batches influence batch-norm statistics.
        test_loss = gmm_loss(test_prediction, targets, instrument_vars) + l2_penalty

        params = layers.get_all_params(self.treatment_output, trainable=True)
        param_updates = updates.adadelta(loss, params)

        loss_inputs = [
            input_vars,
            targets,
            instrument_vars,
        ]

        self._train_fn = theano.function(
            loss_inputs,
            loss,
            updates=param_updates
        )

        self._loss_fn = theano.function(
            loss_inputs,
            test_loss,
        )

        self._output_fn = theano.function(
            [
                input_vars,
            ],
            test_prediction,
        )

        return init_params
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号