def dist_info_sym(self, obs_var, state_info_vars):
    """Build the symbolic ``prob`` output for a batch of observation sequences.

    The first two axes of ``obs_var`` are treated as (batch, step); every
    remaining axis is flattened into a single trailing feature axis. When
    ``self.state_include_action`` is truthy, ``state_info_vars["prev_action"]``
    is concatenated onto that feature axis (axis 2) before the result is fed
    to the network.

    NOTE(review): ``TT`` and ``L`` are presumably ``theano.tensor`` and
    ``lasagne.layers`` -- confirm against this file's imports.

    Args:
        obs_var: Symbolic observations with at least two leading axes.
        state_info_vars: Mapping of extra symbolic inputs; only the
            ``"prev_action"`` entry is read, and only when
            ``self.state_include_action`` is truthy.

    Returns:
        A dict with the single key ``"prob"`` holding the result of
        ``L.get_output`` on ``self.prob_network.output_layer``.
    """
    batch_size, horizon = obs_var.shape[:2]
    # Collapse all per-step observation axes into one feature axis.
    sequence_obs = obs_var.reshape((batch_size, horizon, -1))

    if self.state_include_action:
        previous_actions = state_info_vars["prev_action"]
        network_input = TT.concatenate([sequence_obs, previous_actions], axis=2)
    else:
        network_input = sequence_obs

    # The recurrent input layer is always fed; the feature network (when one
    # exists) additionally receives the same data reshaped to two dimensions,
    # with self.input_dim as the width of the second axis.
    layer_inputs = {self.l_input: network_input}
    if self.feature_network is not None:
        flattened_input = TT.reshape(network_input, (-1, self.input_dim))
        layer_inputs[self.feature_network.input_layer] = flattened_input

    prob_var = L.get_output(self.prob_network.output_layer, layer_inputs)
    return dict(prob=prob_var)
categorical_gru_policy.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录