learners.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:edm2016 作者: Knewton 项目源码 文件源码
def __init__(self, correct, student_ids=None, item_ids=None, student_idx=None,
                 item_idx=None, is_held_out=None, num_students=None, num_items=None,
                 **bn_learner_kwargs):
        """
        Validate the interaction data and build the learner's nodes: one theta (student) node,
        one offset-coefficient (item) node, and a train response node plus, when some
        interactions are held out, a test response node.

        :param np.ndarray[bool] correct: a 1D array of correctness values
        :param np.ndarray|None student_ids: student identifiers for each interaction; if no student
            indices provided, sort order of these ids determines theta indices.
        :param np.ndarray|None item_ids: item identifiers for each interaction; if no item indices
            are provided, sort order of these ids determines item indices.
        :param np.ndarray[int]|None student_idx: a 1D array mapping `correct` to student index
        :param np.ndarray[int]|None item_idx: a 1D array mapping `correct` to item index
        :param np.ndarray[bool]|None is_held_out: a 1D boolean array-like (ndarray, pandas Series,
            or list of bools) indicating whether the interaction should be held out from training
            (if not all zeros, a held_out test node will be added to learner)
        :param int|None num_students: optional number of students. Default is one plus
            the maximum index.
        :param int|None num_items: optional number of items. Default is one plus
            the maximum index.
        :param bn_learner_kwargs: arguments to be passed on to the BayesNetLearner init
        :raises ValueError: if `correct`, the student indices and the item indices differ in
            length; if there are no interactions at all; if `is_held_out` is not a boolean
            array of the same length as `correct`; or if every interaction is held out.
        """
        # convert pandas Series to np.ndarray and check argument dimensions
        correct = np.asarray_chkfinite(correct, dtype=bool)
        student_ids, student_idx = check_and_set_idx(student_ids, student_idx, 'student')
        item_ids, item_idx = check_and_set_idx(item_ids, item_idx, 'item')

        if len(correct) != len(student_idx) or len(correct) != len(item_idx):
            raise ValueError("number of elements in correct ({}), student_idx ({}), and item_idx "
                             "({}) must be the same".format(len(correct), len(student_idx),
                                                            len(item_idx)))
        if len(correct) == 0:
            # np.max below would otherwise fail with an opaque "zero-size array" ValueError
            raise ValueError("at least one interaction is required")

        if is_held_out is not None:
            # Coerce *before* validating so that lists/tuples (which have no `.dtype`) are
            # handled uniformly with ndarrays and pandas Series. No dtype is forced here:
            # non-boolean input must still be rejected rather than silently cast.
            is_held_out = np.asarray(is_held_out)
            if len(is_held_out) != len(correct) or is_held_out.dtype != bool:
                raise ValueError("held_out ({}) must be None or an array of bools the same length as "
                                 "correct ({})".format(len(is_held_out), len(correct)))

        self.num_students = set_or_check_min(num_students, np.max(student_idx) + 1, 'num_students')
        self.num_items = set_or_check_min(num_items, np.max(item_idx) + 1, 'num_items')

        theta_node = DefaultGaussianNode(THETAS_KEY, self.num_students, ids=student_ids)
        offset_node = DefaultGaussianNode(OFFSET_COEFFS_KEY, self.num_items, ids=item_ids)
        nodes = [theta_node, offset_node]

        # add response nodes (train/test if there is held-out data; else just the train set)
        if is_held_out is not None and is_held_out.any():
            if is_held_out.all():
                raise ValueError("some interactions must be not held out")
            node_names = (TRAIN_RESPONSES_KEY, TEST_RESPONSES_KEY)
            response_idxs = (np.logical_not(is_held_out), is_held_out)
        else:
            node_names = (TRAIN_RESPONSES_KEY,)
            response_idxs = (np.ones_like(correct, dtype=bool),)

        for node_name, response_idx in zip(node_names, response_idxs):
            # each response node sees only its own subset of interactions, but indexes into
            # the full set of students/items shared through the two parameter nodes
            cpd = OnePOCPD(item_idx=item_idx[response_idx], theta_idx=student_idx[response_idx],
                           num_thetas=self.num_students, num_items=self.num_items)
            # a fresh dict per node, so that nodes never share a mutable mapping
            param_nodes = {THETAS_KEY: theta_node, OFFSET_COEFFS_KEY: offset_node}
            nodes.append(Node(name=node_name, data=correct[response_idx], cpd=cpd,
                              solver_pars=SolverPars(learn=False), param_nodes=param_nodes,
                              held_out=(node_name == TEST_RESPONSES_KEY)))

        # store leaf nodes for learning
        super(OnePOLearner, self).__init__(nodes=nodes, **bn_learner_kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号