def configure_env_dropout(self, env, sampler_params=None, dropout=0.01, tau=0.15,
                          length_scale=1e-2, n_epochs=50):
    """Install on ``env`` a factory that builds dropout-based samplers.

    Nothing is built here. The method only assigns ``env.sampler_factory``,
    and every call to that factory builds a new model and sampler.
    ``self`` is not used.

    Args:
        env: Object providing ``get_default_sampler_params()``,
            ``get_train_x()`` (whose result must have a ``.shape``) and
            ``layers_description``. The attribute ``sampler_factory`` is set
            on it.
        sampler_params: Optional mapping of sampler keyword arguments. It is
            applied last, so it overrides both the env defaults and
            ``n_epochs``. It is read each time the factory runs, not
            snapshotted here.
        dropout: Dropout rate passed to the model and used in the
            regularisation term.
        tau: Precision term in the denominator of the regularisation strength.
        length_scale: Length-scale term in the numerator of the
            regularisation strength.
        n_epochs: Value stored under ``'n_epochs'`` before ``sampler_params``
            is applied. The default of 50 matches the previous hard-coded
            value.

    Returns:
        None.
    """
    def sampler_factory():
        """Build, ``construct()`` and return a new ``DropoutSampler`` for ``env``.

        Raises:
            ZeroDivisionError: With plain Python numbers, raised when ``tau``
                is 0 or the training inputs have zero rows.
        """
        # Copy the defaults so the overrides below never mutate the object
        # the env handed back. NOTE(review): the env may cache/share it; the
        # copy is cheap insurance either way.
        params = dict(env.get_default_sampler_params())
        params['n_epochs'] = n_epochs
        # Regularisation strength l^2 * (1 - p) / (2 * N * tau), where N is
        # the first dimension of the training inputs (presumably the number
        # of training examples). This appears to be the MC-dropout
        # weight-decay expression from Gal & Ghahramani; confirm if relied on.
        wreg = length_scale ** 2 * (1 - dropout) / (2. * env.get_train_x().shape[0] * tau)
        model = DropoutSampler.model_from_description(env.layers_description, wreg, dropout)
        # Lazy %-formatting: the string is built only if INFO is enabled.
        logging.info('Reg: %s', wreg)
        # Caller-supplied params win over env defaults and over n_epochs.
        if sampler_params is not None:
            params.update(sampler_params)
        sampler = DropoutSampler(model=model, **params)
        sampler.construct()
        return sampler

    env.sampler_factory = sampler_factory
评论列表
文章目录