feedforward.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:deep-iv 作者: allentran 项目源码 文件源码
def build_instrument_model(self, n_vars, **kwargs):
        """Build the instrument network and compile its Theano functions.

        Architecture, in order:
          * input layer of width ``n_vars`` followed by dropout (p=0.2)
          * one tanh dense layer followed by dropout (p=0.2)
          * ``n_dense_layers - 1`` further tanh dense layers, each followed
            by dropout (p=0.5)
          * a single linear output unit

        Args:
            n_vars: Number of columns in the instrument matrix (width of the
                input layer).
            **kwargs: Must contain ``'dense_size'`` (units per hidden dense
                layer) and ``'n_dense_layers'`` (number of hidden dense
                layers; the first one is always built).

        Returns:
            The parameter values of the freshly built network, as returned
            by ``layers.get_all_param_values``.

        Raises:
            KeyError: If ``'dense_size'`` or ``'n_dense_layers'`` is missing
                from ``kwargs``.

        Side effects:
            Sets ``self.instrument_output`` (output layer) and compiles
            ``self._instrument_train_fn``, ``self._instrument_loss_fn`` and
            ``self._instrument_output_fn``.
        """
        # Read the required hyper-parameters first so a missing key fails
        # before any graph construction happens.
        dense_size = kwargs['dense_size']
        n_dense_layers = kwargs['n_dense_layers']

        targets = TT.vector()
        instrument_vars = TT.matrix()

        instruments = layers.InputLayer((None, n_vars), instrument_vars)
        instruments = layers.DropoutLayer(instruments, p=0.2)

        dense_layer = layers.DenseLayer(instruments, dense_size, nonlinearity=nonlinearities.tanh)
        dense_layer = layers.DropoutLayer(dense_layer, p=0.2)

        # `range` instead of `xrange`: identical behaviour for this small
        # loop count and does not raise NameError on Python 3.
        for _ in range(n_dense_layers - 1):
            dense_layer = layers.DenseLayer(dense_layer, dense_size, nonlinearity=nonlinearities.tanh)
            dense_layer = layers.DropoutLayer(dense_layer, p=0.5)

        self.instrument_output = layers.DenseLayer(dense_layer, 1, nonlinearity=nonlinearities.linear)
        init_params = layers.get_all_param_values(self.instrument_output)
        # Stochastic pass (dropout active) vs. deterministic pass (dropout off).
        prediction = layers.get_output(self.instrument_output, deterministic=False)
        test_prediction = layers.get_output(self.instrument_output, deterministic=True)

        # flexible here, endog variable can be categorical, continuous, etc.
        l2_cost = regularization.regularize_network_params(self.instrument_output, regularization.l2)
        # Build the mean squared error once and reuse it for both the
        # regularised training loss and the plain reported loss.
        mse = objectives.squared_error(prediction.flatten(), targets.flatten()).mean()
        loss = mse + 1e-4 * l2_cost
        # NOTE(review): the reported loss is built from the stochastic
        # prediction, so dropout is active when _instrument_loss_fn runs.
        # Confirm whether `test_prediction` was intended here.
        loss_total = mse

        params = layers.get_all_params(self.instrument_output, trainable=True)
        param_updates = updates.adadelta(loss, params)

        # Training step: returns the regularised loss and applies updates.
        self._instrument_train_fn = theano.function(
            [
                targets,
                instrument_vars,
            ],
            loss,
            updates=param_updates
        )

        # Loss without the L2 penalty and without parameter updates.
        self._instrument_loss_fn = theano.function(
            [
                targets,
                instrument_vars,
            ],
            loss_total
        )

        # Deterministic prediction from instruments only.
        self._instrument_output_fn = theano.function([instrument_vars], test_prediction)

        return init_params
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号