def V_short(self, eta):
    """Accumulate a binomially weighted probability ratio over all count pairs.

    For every pair ``(n1, n2)`` with ``0 <= n1 <= self.N1`` and
    ``0 <= n2 <= self.N2`` this method:

    1. builds a 7-entry weight vector ``w`` from binomial coefficients,
    2. obtains ``pz0, pz1 = self.p_n_given_z(n1, n2)``,
    3. computes the scalar normaliser
       ``Q = sum(eta * counts * (pz0 * (1 - pZgivenA) + pz1 * pZgivenA))``
       with ``counts = [N1 - n1, N2 - n2, n1, n2, 1, 1, 1]``,
    4. adds ``w * pz0 * ratio`` / ``w * pz1 * ratio`` to two running sums,
       where ``ratio = (pz0 * (1 - pZgivenA) + pz1 * pZgivenA) / Q``.

    NOTE(review): the 7 slots presumably correspond to 7 actions
    (X1=0, X2=0, X1=1, X2=1 and three further ones) and ``pz0``/``pz1`` to
    probabilities conditioned on a binary Z -- this is inferred from the
    names only; confirm against ``p_n_given_z`` and the caller.

    Args:
        eta: Multiplied elementwise with ``pz0``/``pz1`` and the 7-entry
            ``counts`` vector, so it must broadcast against length 7
            (presumably a length-7 weight vector -- TODO confirm).

    Returns:
        ``self.pZgivenA * sum1 + (1 - self.pZgivenA) * sum0`` where ``sum0``
        and ``sum1`` are the length-7 float accumulators described above.
    """
    p_z1 = self.pZgivenA
    p_z0 = 1 - self.pZgivenA  # loop-invariant, so computed once

    sum0 = np.zeros(7, dtype=float)
    sum1 = np.zeros(7, dtype=float)

    for n1, n2 in product(range(self.N1 + 1), range(self.N2 + 1)):
        # Binomial coefficients shared by several of the weights below.
        c1 = comb(self.N1, n1, exact=True)
        c2 = comb(self.N2, n2, exact=True)
        wdo = c1 * c2

        # Slot order must line up with ``counts`` below:
        # [N1 - n1, N2 - n2, n1, n2, 1, 1, 1].
        w = np.asarray([
            comb(self.N1 - 1, n1, exact=True) * c2,
            c1 * comb(self.N2 - 1, n2, exact=True),
            comb(self.N1 - 1, n1 - 1, exact=True) * c2,
            c1 * comb(self.N2 - 1, n2 - 1, exact=True),
            wdo,
            wdo,
            wdo,
        ])

        pz0, pz1 = self.p_n_given_z(n1, n2)

        # ndarray (not list) so the product also works when eta/pz are scalars.
        counts = np.asarray([self.N1 - n1, self.N2 - n2, n1, n2, 1, 1, 1])
        Q = (eta * pz0 * counts * p_z0 + eta * pz1 * counts * p_z1).sum()

        # Q can be zero; nan_to_num maps the resulting nan to 0 and +/-inf to
        # the largest finite float, so the floating-point warnings raised by
        # the division itself are expected and silenced here.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = np.nan_to_num(np.true_divide(pz0 * p_z0 + pz1 * p_z1, Q))

        # np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is the
        # documented replacement and also converts object arrays that arise
        # when the exact binomial coefficients exceed the int64 range.
        sum0 += np.asarray(w * pz0 * ratio, dtype=float)
        sum1 += np.asarray(w * pz1 * ratio, dtype=float)

    return p_z1 * sum1 + p_z0 * sum0
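# 评论列表 (comment list) -- web-page navigation residue, not part of the code
# 文章目录 (article table of contents) -- web-page navigation residue, not part of the code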