bidnn.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:BiDNN 作者: v-v 项目源码 文件源码
def __init__(self, conf):
        """Build the network and compile its Theano functions.

        Reads ``conf.act``, ``conf.lr``, ``conf.momentum`` and
        ``conf.load_model``.

        ``conf.act`` is replaced *in place* by the activation callable that
        corresponds to its name ("linear", "sigmoid", "relu" or "tanh").
        Because the same ``conf`` object may be used to construct more than
        one model, a ``conf.act`` that already holds one of those four
        callables is accepted as-is instead of being rejected.

        Attributes set on the instance:
            conf               the configuration object passed in
            autoencoder        network returned by ``__create_toplogy__``
            out                symbolic output of ``autoencoder``
            train_fn           (inputs1, inputs2, targets) -> mean squared
                               error; applies the ``nesterov_momentum``
                               updates on every call
            reconstruction_fn  (inputs1, inputs2) -> deterministic output of
                               ``autoencoder``
            encoding_fn        (inputs1, inputs2) -> deterministic outputs of
                               the two encoder layers
            blas_nrm2,
            blas_scal          BLAS routines selected for float arrays

        Raises:
            ValueError: if ``conf.act`` is neither a known activation name
                nor one of the corresponding callables.
        """
        self.conf = conf

        # Activation name -> callable. A tuple of pairs keeps the lookup
        # order identical to the original if/elif chain.
        known_activations = (
            ("linear", linear),
            ("sigmoid", sigmoid),
            ("relu", rectify),
            ("tanh", tanh),
        )
        requested = self.conf.act
        for act_name, act_fn in known_activations:
            # `requested is act_fn` covers a conf object whose `act` was
            # already resolved by an earlier construction.
            if requested is act_fn or requested == act_name:
                self.conf.act = act_fn
                break
        else:
            raise ValueError("Unknown activation function", self.conf.act)

        input_var_first = T.matrix('inputs1')
        input_var_second = T.matrix('inputs2')
        target_var = T.matrix('targets')

        # create network
        self.autoencoder, encoder_first, encoder_second = self.__create_toplogy__(
            input_var_first, input_var_second)

        self.out = get_output(self.autoencoder)

        # training objective: mean squared error against the targets
        loss = squared_error(self.out, target_var).mean()

        params = get_all_params(self.autoencoder, trainable=True)
        updates = nesterov_momentum(
            loss, params,
            learning_rate=self.conf.lr, momentum=self.conf.momentum)

        # training function
        self.train_fn = theano.function(
            [input_var_first, input_var_second, target_var],
            loss, updates=updates)

        # function to reconstruct
        test_reconstruction = get_output(self.autoencoder, deterministic=True)
        self.reconstruction_fn = theano.function(
            [input_var_first, input_var_second], test_reconstruction)

        # encoding function
        test_encode = get_output([encoder_first, encoder_second], deterministic=True)
        self.encoding_fn = theano.function(
            [input_var_first, input_var_second], test_encode)

        # utils: pick the BLAS routines matching a float ndarray once, so
        # they can be reused without repeating the lookup
        float_probe = np.array([], dtype=float)
        self.blas_nrm2 = scipy.linalg.get_blas_funcs(('nrm2',), (float_probe,))[0]
        self.blas_scal = scipy.linalg.get_blas_funcs(('scal',), (float_probe,))[0]

        # load weights if necessary
        if self.conf.load_model is not None:
            self.load_model()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号