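import theano.tensor as TT
import lasagne.layers as L


def dist_info_sym(self, obs_var, state_info_vars):
    # obs_var: symbolic tensor of shape (n_batches, n_steps, *obs_shape).
    n_batches, n_steps = obs_var.shape[:2]
    # Flatten all trailing observation dimensions into a single feature axis.
    obs_var = obs_var.reshape((n_batches, n_steps, -1))
    if self._state_include_action:
        # Feed the previous action to the GRU alongside the observation.
        prev_action_var = state_info_vars["prev_action"]
        all_input_var = TT.concatenate(
            [obs_var, prev_action_var],
            axis=2
        )
    else:
        all_input_var = obs_var
    # Evaluate the mean network and the log-std layer on the same input.
    means, log_stds = L.get_output(
        [self._mean_network.output_layer, self._l_log_std], all_input_var
    )
    return dict(mean=means, log_std=log_stds)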
Source file: gaussian_gru_policy.py
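The shape handling in `dist_info_sym` can be checked without Theano. The NumPy sketch below is illustrative only: `build_all_input`, `gru_gaussian_forward`, `sigmoid`, and the parameter layout are hypothetical names, the toy GRU omits biases, and the state-independent log std broadcast to every step is an assumption about what `self._l_log_std` produces, not a reading of rllab's source.

```python
import numpy as np


def build_all_input(obs, prev_action=None, include_action=True):
    """NumPy mirror of the input construction in dist_info_sym.

    obs: (n_batches, n_steps, *obs_shape)
    prev_action: (n_batches, n_steps, action_dim)
    """
    n_batches, n_steps = obs.shape[:2]
    # Collapse all trailing observation dims into one feature axis.
    flat_obs = obs.reshape((n_batches, n_steps, -1))
    if include_action:
        # Append the previous action to the features of each time step.
        return np.concatenate([flat_obs, prev_action], axis=2)
    return flat_obs


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_gaussian_forward(all_input, params, log_std):
    """Hypothetical toy GRU (no biases) plus a linear mean head.

    Stands in for self._mean_network and self._l_log_std; not rllab's code.
    """
    W_z, U_z, W_r, U_r, W_h, U_h, W_out = params
    n_batches, n_steps, _ = all_input.shape
    h = np.zeros((n_batches, U_z.shape[0]))
    means = []
    for t in range(n_steps):
        x = all_input[:, t, :]
        z = sigmoid(x @ W_z + h @ U_z)             # update gate
        r = sigmoid(x @ W_r + h @ U_r)             # reset gate
        h_cand = np.tanh(x @ W_h + (r * h) @ U_h)  # candidate state
        h = (1.0 - z) * h + z * h_cand
        means.append(h @ W_out)
    means = np.stack(means, axis=1)  # (n_batches, n_steps, action_dim)
    # Assumed state-independent log std, broadcast to every batch and step.
    log_stds = np.broadcast_to(log_std, means.shape)
    return dict(mean=means, log_std=log_stds)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    obs = rng.normal(size=(2, 5, 3, 4))       # 2 sequences, 5 steps, 3x4 obs
    prev_action = rng.normal(size=(2, 5, 2))  # action_dim = 2
    all_input = build_all_input(obs, prev_action)
    print(all_input.shape)  # (2, 5, 14)

    in_dim, hid, act = all_input.shape[2], 8, 2
    # W_z, U_z, W_r, U_r, W_h, U_h, then the output projection W_out.
    shapes = [(in_dim, hid), (hid, hid)] * 3 + [(hid, act)]
    params = [rng.normal(scale=0.1, size=s) for s in shapes]
    out = gru_gaussian_forward(all_input, params, log_std=np.zeros(act))
    print(out["mean"].shape, out["log_std"].shape)  # (2, 5, 2) (2, 5, 2)
```

Concatenating along axis 2 keeps the batch and time axes aligned, so the observations and `prev_action` must share the same leading `(n_batches, n_steps)` shape; only the feature axis grows.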