def create_critic(self, name, state_input, action_input, reuse=False):
    """Build the critic graph that maps a (state, action) batch to values.

    The state passes through one dense layer per entry of ``self.n_hiddens``.
    The action is concatenated (along axis 1) to the input of the second
    hidden layer. When fewer than two hidden layers are configured, the
    action is instead concatenated right before the output layer so that
    the value always depends on the action.

    Args:
        name: Variable scope under which all layers are created.
        state_input: Tensor with the batch of states.
        action_input: Tensor with the batch of actions; must be
            concatenable with the hidden activations along axis 1.
        reuse: Forwarded to ``tf.variable_scope`` as its ``reuse`` argument.

    Returns:
        A tuple ``(value, weights, weight_phs)``:
        ``value`` is the output of the single-unit layer flattened to 1-D,
        ``weights`` collects what every ``denselayer`` call returned as its
        second element (in layer order), and ``weight_phs`` holds one
        float32 placeholder per entry of ``weights`` with the same static
        shape (presumably used to feed new weight values; confirm against
        callers).
    """
    hidden = state_input
    weights = []
    action_merged = False
    with tf.variable_scope(name, reuse=reuse):
        for index, n_hidden in enumerate(self.n_hiddens):
            if index == 1:
                # The action joins the network at the second hidden layer.
                hidden = tf.concat([hidden, action_input], axis=1)
                action_merged = True
            hidden, layer_weights = denselayer(
                "hidden_critic_{}".format(index), hidden, n_hidden,
                self.nonlinearity, tf.truncated_normal_initializer())
            weights.extend(layer_weights)

        if not action_merged:
            # With zero or one hidden layer the loop never reaches index 1;
            # without this the value would ignore action_input entirely.
            hidden = tf.concat([hidden, action_input], axis=1)

        # Output layer: one unit, weights drawn uniformly from a small range.
        value, layer_weights = denselayer(
            "value", hidden, 1,
            w_initializer=tf.random_uniform_initializer(-3e-3, 3e-3))
        value = tf.reshape(value, [-1])
        weights.extend(layer_weights)

    # NOTE(review): the original indentation was lost; the placeholders are
    # assumed to be created outside the variable scope. Confirm.
    weight_phs = [tf.placeholder(tf.float32, shape=w.get_shape()) for w in weights]
    return value, weights, weight_phs
# Comment list
# Article contents