def create_architecture(self):
    """Build the recurrent policy/value graph together with the frameskip head.

    Layout: input layers -> fully connected layer -> recurrent cell (unrolled
    by ``tf.nn.dynamic_rnn`` over a single sequence) -> three heads that all
    read the flattened recurrent outputs:

    * policy: ``self.ops.pi_logits`` and their softmax ``self.ops.pi``;
    * value: ``self.ops.v``, flattened to rank 1;
    * frameskip: ``self.ops.frameskip_n`` (``1 + relu(...)``) and
      ``self.ops.frameskip_p`` (clipped sigmoid), also exposed as the pair
      ``self.ops.frameskip_policy``.

    Placeholders created here are ``self.vars.sequence_length`` and the LSTM
    state pair stored in ``self.vars.initial_network_state``. The final
    recurrent state is published as ``self.ops.network_state``.
    """
    self.vars.sequence_length = tf.placeholder(tf.int64, [1], name="sequence_length")

    fc_input = self.get_input_layers()
    fc1 = fully_connected(
        fc_input,
        num_outputs=self.fc_units_num,
        scope=self._name_scope + "/fc1",
    )
    # dynamic_rnn runs batch-major here (time_major=False), so the rows of
    # fc1 become the time steps of a single sequence (batch size 1).
    fc1_reshaped = tf.reshape(fc1, [1, -1, self.fc_units_num])

    self.recurrent_cells = self.ru_class(self._recurrent_units_num)
    state_c = tf.placeholder(
        tf.float32, [1, self.recurrent_cells.state_size.c], name="initial_lstm_state_c")
    state_h = tf.placeholder(
        tf.float32, [1, self.recurrent_cells.state_size.h], name="initial_lstm_state_h")
    self.vars.initial_network_state = LSTMStateTuple(state_c, state_h)

    rnn_outputs, self.ops.network_state = tf.nn.dynamic_rnn(
        self.recurrent_cells,
        fc1_reshaped,
        initial_state=self.vars.initial_network_state,
        sequence_length=self.vars.sequence_length,
        time_major=False,
        scope=self._name_scope)
    # Drop the size-1 batch axis again: one row per time step for the heads.
    reshaped_rnn_outputs = tf.reshape(rnn_outputs, [-1, self._recurrent_units_num])

    # NOTE(review): kept at its original position; presumably reset_state()
    # relies on the recurrent cell created above -- confirm before moving it.
    self.reset_state()

    # Policy head: raw logits plus their softmax.
    self.ops.pi_logits = fully_connected(
        reshaped_rnn_outputs,
        num_outputs=self.actions_num,
        scope=self._name_scope + "/fc_pi",
        activation_fn=None)
    self.ops.pi = tf.nn.softmax(self.ops.pi_logits)

    # Value head: a single output per row, flattened to rank 1.
    state_value = linear(
        reshaped_rnn_outputs,
        num_outputs=1,
        scope=self._name_scope + "/fc_value")
    self.ops.v = tf.reshape(state_value, [-1])

    # Frameskip head: one (n, p) pair per action, or a single shared pair.
    if self.multi_frameskip:
        frameskip_output_len = self.actions_num
    else:
        frameskip_output_len = 1

    if self.fs_stop_gradient:
        # Cut the frameskip head off from the recurrent trunk. The policy and
        # value heads were built above from the un-stopped tensor, so they are
        # not affected by this rebinding.
        reshaped_rnn_outputs = tf.stop_gradient(reshaped_rnn_outputs)

    # relu output is >= 0, so frameskip_n is always >= 1.
    self.ops.frameskip_n = 1 + fully_connected(
        reshaped_rnn_outputs,
        num_outputs=frameskip_output_len,
        scope=self._name_scope + "/fc_frameskip_n",
        activation_fn=tf.nn.relu,
        biases_initializer=tf.constant_initializer(self.fs_n_bias))
    frameskip_p = fully_connected(
        reshaped_rnn_outputs,
        num_outputs=frameskip_output_len,
        scope=self._name_scope + "/fc_frameskip_p",
        activation_fn=tf.nn.sigmoid,
        biases_initializer=tf.constant_initializer(self.fs_p_bias))

    # Keep p strictly inside (0, 1). The lower margin is unchanged. The upper
    # margin cannot also be 1e-20: `1 - 1e-20` evaluates to exactly 1.0 in
    # floating point, which turned the upper clip into a no-op and let a
    # saturated sigmoid through as p == 1.0. 1e-7 is roughly float32 machine
    # epsilon, so `1 - 1e-7` stays below 1.0 (assumes frameskip_p is float32,
    # as the float32 LSTM state placeholders suggest -- confirm).
    lower_eps = 1e-20
    upper_eps = 1e-7
    self.ops.frameskip_p = tf.clip_by_value(frameskip_p, lower_eps, 1 - upper_eps)

    if not self.multi_frameskip:
        # With a single output per row, drop the trailing size-1 axis.
        self.ops.frameskip_n = tf.reshape(self.ops.frameskip_n, (-1,))
        self.ops.frameskip_p = tf.reshape(self.ops.frameskip_p, (-1,))

    self.ops.frameskip_policy = [self.ops.frameskip_n, self.ops.frameskip_p]
# Scraped web-page residue ("Comment list" / "Article table of contents"); not part of the program.