def dist_info(self, obs, state_infos=None):
    """Evaluate the distribution info for ``obs``.

    When ``state_infos`` is ``None`` or empty, the call is forwarded to
    ``self._f_dist_info(obs)``.  Otherwise the latent values stored in
    ``state_infos`` under the keys ``"latent_0"`` ...
    ``"latent_<n_latent_layers - 1>"`` are passed, in index order, together
    with ``obs`` to a function that is compiled on first use and cached on
    ``self._f_dist_info_givens``.

    Args:
        obs: Observations handed unchanged to the compiled function.
        state_infos: Optional mapping holding one entry per latent layer,
            keyed ``"latent_%d" % idx``.

    Returns:
        Whatever the selected compiled function returns.

    Raises:
        KeyError: If ``state_infos`` is a non-empty dict that lacks one of
            the expected ``"latent_%d"`` keys.
    """
    if not state_infos:
        return self._f_dist_info(obs)

    latent_keys = ["latent_%d" % idx for idx in range(self._n_latent_layers)]

    if self._f_dist_info_givens is None:
        # Compile lazily: only callers that actually supply latent values
        # pay for building this second function, and only once.
        obs_var = self._mean_network.input_var
        # NOTE(review): TT is presumably theano.tensor -- confirm against
        # the file's imports.
        latent_vars = [TT.matrix(key) for key in latent_keys]
        self._f_dist_info_givens = ext.compile_function(
            inputs=[obs_var] + latent_vars,
            outputs=self.dist_info_sym(
                obs_var, dict(zip(latent_keys, latent_vars))
            ),
        )

    # Positional order must match the `inputs` list used at compile time:
    # the observation first, then the latents by ascending index.
    latent_vals = [state_infos[key] for key in latent_keys]
    return self._f_dist_info_givens(obs, *latent_vals)
stochastic_gaussian_mlp_policy.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录