USVM.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:uplift 作者: mustafaseisa 项目源码 文件源码
def __init__(self, treatment, control):
        """Partition treatment and control samples by label and reset the model.

        Both arguments are 2-D numpy arrays in which column 0 holds the
        label (-1 or 1) and the remaining columns hold the features.  The
        rows of each group are split into a positive and a negative subset,
        the subset sizes are recorded, and the model parameters are cleared.

        Args:
            treatment: ndarray of treatment-group rows, ``[label, features...]``.
            control: ndarray of control-group rows, ``[label, features...]``.

        Raises:
            AssertionError: if an input is not a plain ndarray, if either
                group has no rows or no feature columns, if there are fewer
                than 3 rows overall, if the groups differ in feature count,
                or if any label is not -1 or 1.

        Side effects:
            Sets ``options['show_progress'] = False`` on the module-level
            ``options`` mapping (presumably the solver options -- confirm
            against the file's imports) and prints a confirmation message.
        """

        # plain numpy arrays only (exact type match, subclasses are rejected)
        assert type(treatment) is ndarray and type(control) is ndarray, \
            "treatment and control must be numpy arrays"

        # first column is the label, everything after it is the feature matrix
        yt, Xt = hsplit(treatment, [1])
        yc, Xc = hsplit(control, [1])

        self._nt, mt = Xt.shape
        self._nc, mc = Xc.shape
        self._n = self._nc + self._nt # n is number of datum across both groups

        # data shouldn't be trivial
        assert min(mt, mc, self._nt, self._nc) >= 1 and self._n >= 3, \
            "each group needs at least one row and one feature, and n >= 3"
        # same number of features in treatment and control
        assert mt == mc, "treatment and control must have the same number of features"

        self._m = mt # store number of features

        # Labels are binary.  The previous test, `unique(y).all() in [-1, 1]`,
        # compared a single boolean against [-1, 1] (True == 1), so it accepted
        # any label set without a zero; check set membership explicitly instead.
        allowed = {-1, 1}
        assert set(unique(yt).tolist()) <= allowed and set(unique(yc).tolist()) <= allowed, \
            "labels must be -1 or 1"

        tLabels = yt.flatten()
        tPlusIndex = where(tLabels == 1.0)[0] # index for positive in treatment
        tMinusIndex = where(tLabels != 1.0)[0] # index for negative in treatment
        self._ntplus = len(tPlusIndex) # number of positive treatment points
        self._ntminus = self._nt - self._ntplus # number of negative treatment points

        self._Dtplus = Xt[tPlusIndex] # positive treatment datum
        self._Dtminus = Xt[tMinusIndex] # negative treatment datum

        cLabels = yc.flatten()
        cPlusIndex = where(cLabels == 1.0)[0] # index for positive in control
        cMinusIndex = where(cLabels != 1.0)[0] # index for negative in control
        self._ncplus = len(cPlusIndex) # number of positive control points
        self._ncminus = self._nc - self._ncplus # number of negative control points

        self._Dcplus = Xc[cPlusIndex] # positive control datum
        self._Dcminus = Xc[cMinusIndex] # negative control datum

        # model parameters

        self.__optimized = False # indicator for whether optimization routine was performed
        options['show_progress'] = False # suppress optimization output

        self.w = None # hyperplane slope
        self.b1 = None # treatment group intercept
        self.b2 = None # control group intercept
        self.threshold = None # thresholding predictor function

        print("Successfully initialized.")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号