def build_network(self):
    """Build the TensorFlow graph for this network and store its endpoints.

    Graph layout, in construction order:

    * a float32 input placeholder of shape ``[None, 84, 84, 4]``
      (presumably a batch of four stacked 84x84 frames -- confirm with
      the caller that feeds it);
    * a convolution with 16 filters, 8x8 kernel, stride 4, ReLU;
    * a convolution with 32 filters, 4x4 kernel, stride 2, ReLU;
    * a flatten followed by a 256-unit fully connected ReLU layer;
    * two heads on top of that layer: a softmax layer with
      ``self.nb_actions`` outputs and a single-unit layer with no
      activation.

    Every layer is created under a variable scope prefixed with
    ``self.name``. One scalar summary holding the global norm of each
    layer's variables is registered as well.

    Side effects: sets ``self._tf_state`` (input placeholder),
    ``self._tf_adv_probas`` (softmax head) and ``self._tf_value_state``
    (single-unit head). Nothing is returned.
    """

    def scoped(suffix):
        # All variable scopes hang off this instance's name.
        return self.name + suffix

    # Input placeholder; the leading dimension is left unspecified.
    frames = tf.placeholder(tf.float32, shape=[None, 84, 84, 4])

    # Convolutional trunk.
    conv_a = slim.conv2d(
        frames, 16, [8, 8],
        stride=4,
        activation_fn=nn.relu,
        scope=scoped('/cnn_1'),
    )
    conv_b = slim.conv2d(
        conv_a, 32, [4, 4],
        stride=2,
        activation_fn=nn.relu,
        scope=scoped('/cnn_2'),
    )

    # Shared dense layer feeding both heads.
    hidden = slim.fully_connected(
        slim.flatten(conv_b), 256,
        activation_fn=nn.relu,
        scope=scoped('/fcc_1'),
    )

    # Softmax head over the actions, then the linear single-unit head.
    probas_head = slim.fully_connected(
        hidden, self.nb_actions,
        activation_fn=nn.softmax,
        scope=scoped('/adv_probas'),
    )
    value_head = slim.fully_connected(
        hidden, 1,
        activation_fn=None,
        scope=scoped('/value_state'),
    )

    # One global-norm summary per layer, registered in layer order.
    norm_summaries = (
        ("model/cnn1_global_norm", '/cnn_1'),
        ("model/cnn2_global_norm", '/cnn_2'),
        ("model/fcc1_global_norm", '/fcc_1'),
        ("model/adv_probas_global_norm", '/adv_probas'),
        ("model/value_state_global_norm", '/value_state'),
    )
    for tag, suffix in norm_summaries:
        layer_vars = slim.get_variables(scope=scoped(suffix))
        tf.summary.scalar(tag, tf.global_norm(layer_vars))

    # Input
    self._tf_state = frames
    # Output
    self._tf_adv_probas = probas_head
    self._tf_value_state = value_head
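# 评论列表 (comment list) -- web-page residue, not part of the code
# 文章目录 (article table of contents) -- web-page residue, not part of the code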