# Module-level imports assumed by this method:
import lasagne
import lasagne.layers as ls
from lasagne.nonlinearities import rectify


def _initialize_network(self, img_input_shape, misc_len, output_size, img_input, misc_input=None, **kwargs):
    input_layers = []
    inputs = [img_input]
    # weights_init = lasagne.init.GlorotUniform("relu")
    weights_init = lasagne.init.HeNormal("relu")

    # Convolutional part of the network, fed with the image input.
    network = ls.InputLayer(shape=img_input_shape, input_var=img_input)
    input_layers.append(network)
    network = ls.Conv2DLayer(network, num_filters=32, filter_size=8, nonlinearity=rectify, W=weights_init,
                             b=lasagne.init.Constant(.1), stride=4)
    network = ls.Conv2DLayer(network, num_filters=64, filter_size=4, nonlinearity=rectify, W=weights_init,
                             b=lasagne.init.Constant(.1), stride=2)
    network = ls.Conv2DLayer(network, num_filters=64, filter_size=3, nonlinearity=rectify, W=weights_init,
                             b=lasagne.init.Constant(.1), stride=1)

    if self.misc_state_included:
        # Concatenate the flattened conv features with the misc (game variables) vector.
        inputs.append(misc_input)
        network = ls.FlattenLayer(network)
        misc_input_layer = ls.InputLayer(shape=(None, misc_len), input_var=misc_input)
        input_layers.append(misc_input_layer)
        if "additional_misc_layer" in kwargs:
            misc_input_layer = ls.DenseLayer(misc_input_layer, int(kwargs["additional_misc_layer"]),
                                             nonlinearity=rectify,
                                             W=weights_init, b=lasagne.init.Constant(0.1))
        network = ls.ConcatLayer([network, misc_input_layer])

    # Dueling architecture: separate advantage and state-value streams.
    advantages_branch = ls.DenseLayer(network, 256, nonlinearity=rectify,
                                      W=weights_init, b=lasagne.init.Constant(.1))
    advantages_branch = ls.DenseLayer(advantages_branch, output_size, nonlinearity=None,
                                      b=lasagne.init.Constant(.1))

    state_value_branch = ls.DenseLayer(network, 256, nonlinearity=rectify,
                                       W=weights_init, b=lasagne.init.Constant(.1))
    state_value_branch = ls.DenseLayer(state_value_branch, 1, nonlinearity=None,
                                       b=lasagne.init.Constant(.1))

    # Merge the two streams into Q-values.
    network = DuellingMergeLayer([advantages_branch, state_value_branch])
    return network, input_layers, inputs
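
DuellingMergeLayer is referenced above but not defined in this snippet. Below is a minimal sketch of what such a custom Lasagne merge layer could look like, assuming the standard dueling aggregation Q(s, a) = V(s) + A(s, a) - mean_a A(s, a); the class name matches the call above, but the body is an illustration under that assumption, not the original implementation.

import lasagne.layers as ls
import theano.tensor as T


class DuellingMergeLayer(ls.MergeLayer):
    """Merges an advantages branch (batch, n_actions) and a state-value
    branch (batch, 1) into Q-values via the dueling aggregation."""

    def get_output_shape_for(self, input_shapes):
        # The output keeps the shape of the advantages branch: (batch, n_actions).
        return input_shapes[0]

    def get_output_for(self, inputs, **kwargs):
        advantages, state_value = inputs
        # Make the single state-value column broadcastable across actions.
        state_value = T.addbroadcast(state_value, 1)
        # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        return state_value + (advantages - advantages.mean(axis=1, keepdims=True))

Subtracting the mean advantage keeps the V/A decomposition identifiable, as in the dueling DQN paper; the layer itself has no trainable parameters.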