def dist_info_sym(self, obs_var, state_info_vars):
    """Build the symbolic mean / log-std outputs for the given observations.

    The observation tensor is reshaped to three dimensions, keeping its two
    leading (dynamic) sizes and collapsing everything else into the last
    one.  When ``self.state_include_action`` is truthy, the tensor stored
    under ``state_info_vars["prev_action"]`` is concatenated onto it.  The
    result is fed to ``self.l_input`` and, if a feature network is present,
    a flattened ``(-1, self.input_dim)`` view of it is additionally fed to
    ``self.feature_network.input_layer``.

    Args:
        obs_var: Symbolic observation tensor with at least two dimensions.
            Presumably laid out as (batch, time step, ...) judging by the
            original local names -- TODO confirm against callers.
        state_info_vars: Mapping of symbolic state-info tensors; only the
            ``"prev_action"`` entry is read, and only when
            ``self.state_include_action`` is truthy.

    Returns:
        A dict with keys ``"mean"`` and ``"log_std"`` holding the outputs
        of ``self.mean_network.output_layer`` and ``self.l_log_std``.
    """
    # NOTE(review): ``tf.pack`` and ``tf.concat(dim, values)`` are the
    # pre-1.0 TensorFlow call forms; they are kept as-is so the block still
    # targets the same TF version as the rest of the file.
    batch_dim = tf.shape(obs_var)[0]
    step_dim = tf.shape(obs_var)[1]
    shaped_obs = tf.reshape(obs_var, tf.pack([batch_dim, step_dim, -1]))

    if self.state_include_action:
        # Append the previous action along the last of the three dimensions.
        network_input = tf.concat(
            2, [shaped_obs, state_info_vars["prev_action"]]
        )
    else:
        network_input = shaped_obs

    # The recurrent input layer always receives the 3-D tensor; the feature
    # network (when configured) receives the same data flattened to 2-D.
    layer_inputs = {self.l_input: network_input}
    if self.feature_network is not None:
        layer_inputs[self.feature_network.input_layer] = tf.reshape(
            network_input, (-1, self.input_dim)
        )

    means, log_stds = L.get_output(
        [self.mean_network.output_layer, self.l_log_std],
        layer_inputs,
    )
    return {"mean": means, "log_std": log_stds}
评论列表
文章目录