def set_prior(self, bernoulli_prior=None, density_conditions=None):
    """Configure the Bernoulli prior and the density constraints.

    Args:
        bernoulli_prior: Prior probability that an entry is one, or None for
            no prior. Must lie strictly between 0 and 1 (element-wise if an
            array is given), because its log-odds are stored.
        density_conditions: Four non-negative integer bounds on the number
            of ones per dimension, ordered
            [min_row, min_col, max_row, max_col]. A value of 0 means that
            bound is unrestricted. None (the default) leaves all four
            unrestricted.

    Sets:
        self.density_conditions: np.int8 array of shape (4,).
        self.bernoulli_prior: the prior exactly as passed in.
        self.logit_bernoulli_prior: log(p / (1 - p)), or 0 when no prior.

    Raises:
        ValueError: if density_conditions does not hold exactly four values,
            if any of them is negative or exceeds the int8 range, or if
            bernoulli_prior is outside the open interval (0, 1).
    """
    int8_max = np.iinfo(np.int8).max

    # Validate everything before touching self, so a rejected call leaves
    # the previously configured prior intact.
    if density_conditions is None:
        conditions = np.zeros(4, dtype=np.int8)
    else:
        raw = np.asarray(density_conditions)
        if raw.shape != (4,):
            raise ValueError(
                "density_conditions must contain exactly 4 values "
                "[min_row, min_col, max_row, max_col], got shape {}".format(
                    raw.shape
                )
            )
        # Values are stored as int8. Without this check, values above 127
        # would wrap around (older NumPy) or raise an opaque OverflowError.
        if np.any(raw < 0) or np.any(raw > int8_max):
            raise ValueError(
                "density_conditions must lie in [0, {}], got {}".format(
                    int8_max, list(raw)
                )
            )
        conditions = raw.astype(np.int8)

    if bernoulli_prior is None:
        logit = 0
    else:
        probs = np.asarray(bernoulli_prior, dtype=float)
        # The logit is only finite on (0, 1). This also rejects NaN.
        if not np.all((probs > 0) & (probs < 1)):
            raise ValueError(
                "bernoulli_prior must lie strictly between 0 and 1, "
                "got {}".format(bernoulli_prior)
            )
        logit = np.log(bernoulli_prior / (1 - bernoulli_prior))

    self.density_conditions = conditions
    self.bernoulli_prior = bernoulli_prior
    self.logit_bernoulli_prior = logit
评论列表
文章目录