sklearn_wrapper.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:chainer_sklearn 作者: corochann 项目源码 文件源码
def __init__(self,
                 predictor=None,
                 lossfun=softmax_cross_entropy.softmax_cross_entropy,
                 accfun=accuracy.accuracy,
                 device=-1,
                 **sk_params
                 ):
        """Wrap a chainer predictor so it can be used like a scikit-learn estimator.

        :param predictor: the model to wrap. One of:

            * a ``chainer.Link`` instance -- registered as ``self.predictor``
              inside ``init_scope()``; its class is also stored as
              ``self.predictor_constructor``;
            * a ``chainer.Link`` subclass, or a function (per ``is_function``)
              -- stored as ``self.predictor_constructor`` only, without being
              called here;
            * ``None`` -- a ``chainer.links.Linear(None, self._default_n_out)``
              instance is created and handled as above.
        :param lossfun: loss function.
        :param accfun: accuracy function. When ``None`` is set, accuracy is
            not calculated during the training and ``lossfun`` is used for
            ``score``.
        :param device (int): GPU device id. -1 indicates to use CPU.
        :param sk_params (dict): remaining keyword parameters, stored as
            ``self.sk_params``. This is used internally by ``GridSearchCV``
            and ``RandomizedSearchCV``.
        :raises TypeError: if ``predictor`` is none of the accepted kinds.
        """
        super(SklearnBaseWrapper, self).__init__()
        if predictor is None:
            # Temporary countermeasure to pass sklearn's `check_estimator`,
            # which requires the estimator to be default-constructible.
            # TODO: Should dynamically assign n_out, instead of using a magic
            # parameter.
            predictor = chainer.links.Linear(None, self._default_n_out)
        if isinstance(predictor, chainer.Link):
            # Already-built link: register it under init_scope() so chainer
            # treats it as a child link, and record its class as well.
            with self.init_scope():
                self.predictor = predictor
            self.predictor_constructor = predictor.__class__
        elif is_function(predictor) or (
                isinstance(predictor, type)
                and issubclass(predictor, chainer.Link)):
            # Constructor (Link subclass or factory function): store it only.
            # The isinstance(..., type) guard keeps issubclass() from raising
            # its own opaque TypeError on non-class inputs.
            self.predictor_constructor = predictor
        else:
            # Raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would silently leave the object without a
            # `predictor_constructor`.
            raise TypeError(
                "predictor should be either a chainer.Link instance, a "
                "chainer.Link subclass, or a function which returns a "
                "chainer.Link instance; got {!r}".format(type(predictor)))

        self.lossfun = lossfun
        self.accfun = accfun
        self.compute_accuracy = accfun is not None
        # Per-call state; initialised to None here.
        self.y = None
        self.loss = None
        self.accuracy = None
        self.inputs = None

        # Ensure initialization, necessary for GridSearch: start from a known
        # CPU state (device -1, predictor on CPU if one was built), then let
        # update_device() handle the requested `device`.
        self.device = -1
        if hasattr(self, 'predictor'):
            self.predictor.to_cpu()
        self.update_device(device)

        self.sk_params = sk_params
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号