def __init__(self, loss_fn=None, initial_position=None, test_model=None, batch_size=None, burn_in=0,
             step_sizes=.0001, step_probabilities=1., **kwargs):
    """
    Creates a new MCMC_sampler object.
    :param loss_fn: Target loss function without regularisation terms
    :param initial_position: Initial network weights as a 2-d array of shape
        [number of chains, number of weights]. A 1-d array of weights is treated as a single chain.
        Required in practice, despite the None default.
    :param test_model: The model used on the test data. Default=None
    :param batch_size: Batch size used for stochastic sampling methods. Default=None, in which case
        ``self.train_size`` is used (presumably set by the parent constructor -- confirm).
    :param burn_in: Number of burn-in samples. Default=0
    :param step_sizes: Step size or a list of step sizes. Default=.0001
    :param step_probabilities: Probabilities to choose a step from step_sizes, must sum to 1. Default=1
    :param kwargs: Forwarded unchanged to the parent class constructor.
    :raises ValueError: if initial_position is None or is not a 1-d/2-d array.
    """
    # Parent constructor runs first; train_size / input_dim / output_dim used below are not set
    # in this method, so they are presumably provided by it (NOTE(review): confirm in base class).
    super().__init__(**kwargs)
    self.loss_fn = loss_fn
    self.test_model = test_model

    # np.asarray(None) gives a 0-d array whose .shape[1] lookup fails with an opaque IndexError,
    # so report the missing argument explicitly instead.
    if initial_position is None:
        raise ValueError("initial_position is required: expected an array of shape "
                         "[number of chains, number of weights]")
    # atleast_2d promotes a single weight vector to a one-chain batch of shape [1, n].
    self.initial_position = np.atleast_2d(np.asarray(initial_position, dtype=np.float32))
    if self.initial_position.ndim != 2:
        raise ValueError("initial_position must be a 2-d array of shape "
                         "[number of chains, number of weights], got shape {}"
                         .format(self.initial_position.shape))
    self.position_shape = self.initial_position.shape
    self.position_size = self.initial_position.shape[1]  # total number of parameters of one network
    self.chains_num = self.initial_position.shape[0]  # number of chains to run in parallel

    # data and parameter shapes; None means "use the full training set as one batch"
    self.batch_size = batch_size if batch_size is not None else self.train_size
    self.batch_x_shape = (self.batch_size, self.input_dim)
    self.batch_y_shape = (self.batch_size, self.output_dim)

    # common parameters; atleast_1d lets a scalar and a list be handled uniformly downstream
    self.step_sizes = np.atleast_1d(np.asarray(step_sizes, dtype=np.float32))
    self.step_probabilities = np.atleast_1d(np.asarray(step_probabilities, dtype=np.float32))
    self.burn_in = burn_in
    # one multiplier per chain, starting at 1
    self.step_multiplier = np.ones(shape=(self.chains_num,), dtype=np.float32)

    # monitor acceptance rate for reporting (one running value per chain, initialised to 1)
    self.avg_acceptance_rate = np.ones(shape=(self.chains_num,), dtype=np.float32)
    # presumably the decay factor of a running average of the acceptance rate -- confirm where updated
    self.avg_acceptance_rate_lambda = 0.99
    self._has_burned_in = False
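# NOTE: stray web-page residue removed from the code path: "评论列表" (comment list),
# "文章目录" (table of contents). Not part of the program.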