def create_network(self, input, trainable):
    """Build the convolutional Q-network graph and return its output tensor.

    Args:
        input: Rank-4 tensor holding a stack of black-and-white frames, with
            the frame stack on axis 1 (the transpose below moves axis 1 to
            the last position). NOTE(review): presumably laid out as
            (batch, frames, height, width) -- confirm against the caller.
        trainable: Whether the created variables are trainable. Only a
            trainable network receives an L2 weight regularizer.

    Returns:
        Output of a final fully connected layer with no activation and
        ``self.dim_actions`` units, i.e. one Q-value estimate per action.
    """
    # Weight decay is attached only to variables that will be optimized; a
    # non-trainable copy (presumably a target network) gets no regularizer.
    shared = {
        "weights_regularizer": (
            slim.l2_regularizer(self.regularization) if trainable else None
        ),
        "trainable": trainable,
    }

    # The frame stack sits on axis 1; NHWC layout expects channels last, so
    # move the stack into the channel position.
    features = tf.transpose(input, perm=[0, 2, 3, 1])

    # Two conv + max-pool (2x2 window, stride 2) stages:
    #   8 filters of 7x7 with stride 3, then 16 filters of 3x3 with stride 1.
    conv_stages = ((8, (7, 7), 3), (16, (3, 3), 1))
    for filters, kernel, step in conv_stages:
        features = slim.conv2d(
            features,
            filters,
            kernel,
            stride=step,
            data_format="NHWC",
            activation_fn=tf.nn.relu,
            **shared
        )
        features = slim.max_pool2d(features, kernel_size=2, stride=2)

    hidden = slim.fully_connected(
        slim.flatten(features), 256, activation_fn=tf.nn.relu, **shared
    )
    # Linear output head: raw Q(s, a) estimates, one unit per action.
    return slim.fully_connected(
        hidden, self.dim_actions, activation_fn=None, **shared
    )
评论列表
文章目录