def make_dqn_body(input_layer, trainable=True):
    """Stack two ReLU convolution layers on `input_layer` and flatten the result.

    Args:
        input_layer: Tensor passed as `inputs` to the first convolution
            (presumably a batch of image-like observations -- confirm
            against the caller).
        trainable: Forwarded unchanged to both `layers.conv2d` calls.

    Returns:
        A tuple `(out, end_points)` where `out` is the flattened output of
        the second convolution and `end_points` maps the keys `'conv1'`,
        `'conv2'` and `'conv2_flatten'` to the corresponding intermediate
        tensors.
    """
    # (scope, num_outputs, kernel_size, stride) for each convolution, in order.
    conv_specs = (
        ("conv1", 16, [8, 8], [4, 4]),
        ("conv2", 32, [4, 4], [2, 2]),
    )

    end_points = {}
    features = input_layer
    for scope_name, num_filters, kernel, strides in conv_specs:
        features = layers.conv2d(
            inputs=features,
            num_outputs=num_filters,
            kernel_size=kernel,
            stride=strides,
            activation_fn=tf.nn.relu,
            padding="same",
            scope=scope_name,
            trainable=trainable,
        )
        # Keep every intermediate activation reachable by its scope name.
        end_points[scope_name] = features

    flattened = layers.flatten(features)
    end_points['conv2_flatten'] = flattened
    return flattened, end_points
# Web-page residue from the original listing: "comment list" / "article table of contents".