def __init__(self, state_shape, n_hidden, summary=True):
    """Build the critic graph under the "critic" variable scope.

    The critic is a single tanh hidden layer followed by a one-output
    value head; its loss is the summed squared error against ``r``.

    Args:
        state_shape: Shape of one state, without the batch dimension.
            Any sequence of ints is accepted (list or tuple); it is
            converted to a list before the batch dimension is prepended.
        n_hidden: Number of units in the tanh hidden layer.
        summary: Not read anywhere in this method.
            NOTE(review): confirm whether callers expect it to toggle
            summary creation.

    Attributes set:
        states: ``"float"`` placeholder of shape ``[None] + state_shape``.
        r: float32 placeholder of shape ``[None]``; presumably the target
            returns for each state -- confirm against the caller.
        value: value-head output flattened to 1-D with ``reshape(..., [-1])``.
        loss: ``reduce_sum((value - r) ** 2)`` over the batch.
        summary_loss: the same tensor as ``loss``.
        vars: trainable variables whose names match the current variable
            scope name (obtained inside the "critic" scope).
    """
    super(CriticNetwork, self).__init__()
    self.state_shape = state_shape
    self.n_hidden = n_hidden
    with tf.variable_scope("critic"):
        # list(...) lets tuple shapes work: [None] + (4,) raises TypeError.
        self.states = tf.placeholder(
            "float", [None] + list(self.state_shape), name="states")
        self.r = tf.placeholder(tf.float32, [None], name="r")
        hidden = tf.contrib.layers.fully_connected(
            inputs=self.states,
            num_outputs=self.n_hidden,
            activation_fn=tf.tanh,
            weights_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.02),
            biases_initializer=tf.zeros_initializer(),
            scope="L1")
        # `linear` and `normalized_columns_initializer` are helpers defined
        # elsewhere; presumably a dense layer with 1 output -- confirm.
        # reshape(..., [-1]) flattens the head output to 1-D so it can be
        # compared element-wise with the 1-D `self.r` placeholder.
        self.value = tf.reshape(
            linear(hidden, 1, "value", normalized_columns_initializer(1.0)), [-1])
        # Sum (not mean) of squared errors: the loss magnitude grows with
        # batch size.
        self.loss = tf.reduce_sum(tf.square(self.value - self.r))
        self.summary_loss = self.loss
        # Read inside the "critic" scope so the collection is filtered by
        # this scope's name.
        self.vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
Comment list
Table of contents