models.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:causal_bandits 作者: finnhacks42 项目源码 文件源码
def V_short(self,eta):
        sum0 = np.zeros(7,dtype=float)
        sum1 = np.zeros(7,dtype=float)
        for n1,n2 in product(range(self.N1+1),range(self.N2+1)):
             wdo = comb(self.N1,n1,exact=True)*comb(self.N2,n2,exact=True)
             wdox10 = comb(self.N1-1,n1,exact=True)*comb(self.N2,n2,exact=True)
             wdox11 = comb(self.N1-1,n1-1,exact=True)*comb(self.N2,n2,exact=True)
             wdox20 = comb(self.N1,n1,exact=True)*comb(self.N2-1,n2,exact=True)
             wdox21 = comb(self.N1,n1,exact=True)*comb(self.N2-1,n2-1,exact=True)
             w = np.asarray([wdox10,wdox20,wdox11,wdox21,wdo,wdo,wdo])

             pz0,pz1 = self.p_n_given_z(n1,n2)

             counts = [self.N1-n1,self.N2-n2,n1,n2,1,1,1]
             Q = (eta*pz0*counts*(1-self.pZgivenA)+eta*pz1*counts*self.pZgivenA).sum()

             ratio = np.nan_to_num(np.true_divide(pz0*(1-self.pZgivenA)+pz1*self.pZgivenA,Q))

             sum0 += np.asfarray(w*pz0*ratio)
             sum1 += np.asfarray(w*pz1*ratio)
        result = self.pZgivenA*sum1+(1-self.pZgivenA)*sum0
        return result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号