def dist_info_sym(self, obs_var, state_info_vars):
    """Return the symbolic distribution parameters for a batch of sequences.

    The observations are reshaped so that everything after the first two
    axes is collapsed into one trailing axis.  When
    ``self._state_include_action`` is truthy, ``state_info_vars["prev_action"]``
    is concatenated onto the observations along axis 2 before the networks
    are evaluated; otherwise the reshaped observations are used alone.

    Args:
        obs_var: Observation variable with at least two leading axes; the
            first two entries of its ``shape`` are kept and the rest are
            flattened.  (Presumably batch and time axes -- confirm against
            callers.)
        state_info_vars: Mapping that must contain the key
            ``"prev_action"`` when ``self._state_include_action`` is truthy.
            NOTE(review): the previous-action variable is assumed to be
            3-D and to match ``obs_var`` on its first two axes, since it is
            concatenated on axis 2 -- verify.

    Returns:
        A dict with keys ``"mean"`` and ``"log_std"`` holding the outputs
        of ``self._mean_network.output_layer`` and ``self._l_log_std``
        respectively, as produced by ``L.get_output``.
    """
    batch_dim, step_dim = obs_var.shape[:2]
    # Keep the two leading axes and collapse every remaining axis into one.
    flat_obs = obs_var.reshape((batch_dim, step_dim, -1))

    network_input = flat_obs
    if self._state_include_action:
        # Append the previous action along the last (third) axis.
        network_input = TT.concatenate(
            [flat_obs, state_info_vars["prev_action"]],
            axis=2,
        )

    # Evaluate both output layers on the same input in a single call.
    output_layers = [self._mean_network.output_layer, self._l_log_std]
    mean_var, log_std_var = L.get_output(output_layers, network_input)
    return {"mean": mean_var, "log_std": log_std_var}
评论列表
文章目录