mcmc_sampler.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:bnn-analysis 作者: myshkov 项目源码 文件源码
def __init__(self, loss_fn=None, initial_position=None, test_model=None, batch_size=None, burn_in=0,
                 step_sizes=.0001, step_probabilities=1., **kwargs):
        """
        Creates a new MCMC_sampler object.

        :param loss_fn: Target loss function without regularisation terms
        :param initial_position: Initial network weights as a 2-d array of shape [number of chains, number of weights].
            A 1-d array is accepted and treated as a single chain of shape [1, number of weights].
        :param test_model: The model used on the test data. Default=None
        :param batch_size: Batch size used for stochastic sampling methods. Default=None, in which case
            ``self.train_size`` is used (presumably set by the base-class constructor -- TODO confirm).
        :param burn_in: Number of burn-in samples. Default=0
        :param step_sizes: Step size or a list of step sizes. Default=.0001
        :param step_probabilities: Probabilities to choose a step from step_sizes, must sum to 1. Default=1
        :param kwargs: Forwarded unchanged to the base-class constructor.
        :raises ValueError: If ``initial_position`` is missing or has more than 2 dimensions.
        """

        # Validate before doing any base-class work. np.asarray(None, dtype=np.float32) does NOT fail:
        # it returns a 0-d NaN array, so a missing position has to be rejected explicitly.
        if initial_position is None:
            raise ValueError("initial_position must be provided as an array of shape "
                             "[number of chains, number of weights]")

        super().__init__(**kwargs)
        self.loss_fn = loss_fn
        self.test_model = test_model

        # atleast_2d promotes a single weight vector (n,) to one chain (1, n); 2-d input is left unchanged.
        self.initial_position = np.atleast_2d(np.asarray(initial_position, dtype=np.float32))
        if self.initial_position.ndim != 2:
            raise ValueError("initial_position must be a 2-d array of shape [number of chains, number of weights], "
                             "got shape {}".format(self.initial_position.shape))
        self.position_shape = self.initial_position.shape
        self.position_size = self.initial_position.shape[1]  # total number of parameters of one network

        # data and parameter shapes
        self.chains_num = self.initial_position.shape[0]  # number of chains to run in parallel
        # train_size / input_dim / output_dim are not set here; presumably the base class provides them.
        self.batch_size = batch_size if batch_size is not None else self.train_size
        self.batch_x_shape = (self.batch_size, self.input_dim)
        self.batch_y_shape = (self.batch_size, self.output_dim)

        # common parameters
        # Scalars are promoted to 1-element arrays so callers can always treat these as sequences.
        self.step_sizes = np.atleast_1d(np.asarray(step_sizes, dtype=np.float32))
        self.step_probabilities = np.atleast_1d(np.asarray(step_probabilities, dtype=np.float32))
        self.burn_in = burn_in
        # One multiplier per chain, initialised to 1 (its update rule is not visible here).
        self.step_multiplier = np.ones(shape=(self.chains_num,), dtype=np.float32)

        # monitor acceptance rate for reporting
        self.avg_acceptance_rate = np.ones(shape=(self.chains_num,), dtype=np.float32)
        # NOTE(review): presumably the smoothing factor of a running average of the acceptance rate;
        # confirm at the update site.
        self.avg_acceptance_rate_lambda = 0.99
        self._has_burned_in = False
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号