def _initialize_network(self, img_input_shape, misc_len, output_size, img_input, misc_input=None, **kwargs):
    """Build the network graph and return its output layer plus its inputs.

    The image goes through three ReLU conv layers and is flattened. If
    ``self.misc_state_included`` is true, the misc-state columns are added
    before the dense layers:

    - Columns ``0..3`` are one-hot encoded with 100 units each.
    - Columns ``4..7`` are one-hot encoded with 525 units each.
    - Any remaining columns are fed in as-is.

    These misc branches are concatenated with the flattened conv features.
    A 512-unit ReLU dense layer and a linear output layer follow.

    Args:
        img_input_shape: shape given to the image ``InputLayer``.
        misc_len: total number of columns in ``misc_input``.
        output_size: number of units in the final (linear) dense layer.
        img_input: symbolic variable feeding the image ``InputLayer``.
        misc_input: symbolic 2-D misc-state variable (indexed as
            ``misc_input[:, i]``). It is required when
            ``self.misc_state_included`` is true and ignored otherwise.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        A tuple ``(network, input_layers, inputs)``:

        - ``network``: the output layer.
        - ``input_layers``: every ``InputLayer`` that was created.
        - ``inputs``: the symbolic expressions feeding those layers, in the
          same order as ``input_layers``.

    Raises:
        ValueError: if ``self.misc_state_included`` is true but
            ``misc_input`` is None.
    """
    input_layers = []
    inputs = [img_input]
    weights_init = lasagne.init.HeNormal("relu")

    network = ls.InputLayer(shape=img_input_shape, input_var=img_input)
    input_layers.append(network)
    # (num_filters, filter_size, stride) for each ReLU conv layer, in order.
    for num_filters, filter_size, stride in ((32, 8, 4), (64, 4, 2), (64, 3, 1)):
        network = ls.Conv2DLayer(network, num_filters=num_filters, filter_size=filter_size,
                                 nonlinearity=rectify, W=weights_init,
                                 b=lasagne.init.Constant(0.1), stride=stride)
    network = ls.FlattenLayer(network)

    if self.misc_state_included:
        if misc_input is None:
            raise ValueError("misc_input is required when misc_state_included is set")
        layers_for_merge = []
        health_inputs = 4
        units_per_health_input = 100
        time_inputs = 4
        # TODO set this somewhere else cause it depends on skiprate and timeout ....
        units_per_time_input = 525
        # The leading misc columns are one-hot encoded, group by group, in column order.
        # The "- 1" shifts each value down by one before encoding. Presumably the
        # values are 1-based indices; confirm against whatever produces misc_input.
        one_hot_groups = ((health_inputs, units_per_health_input),
                          (time_inputs, units_per_time_input))
        column = 0
        for group_size, num_units in one_hot_groups:
            for _ in range(group_size):
                oh_input = lasagne.utils.one_hot(misc_input[:, column] - 1, num_units)
                oh_layer = ls.InputLayer(shape=(None, num_units), input_var=oh_input)
                inputs.append(oh_input)
                input_layers.append(oh_layer)
                layers_for_merge.append(oh_layer)
                column += 1
        # Every column after the one-hot-encoded ones is fed in unchanged.
        other_misc_input = misc_input[:, column:]
        other_misc_input_layer = ls.InputLayer(shape=(None, misc_len - column),
                                               input_var=other_misc_input)
        input_layers.append(other_misc_input_layer)
        layers_for_merge.append(other_misc_input_layer)
        inputs.append(other_misc_input)
        # The conv features go last in the concatenation.
        layers_for_merge.append(network)
        network = ls.ConcatLayer(layers_for_merge)

    network = ls.DenseLayer(network, 512, nonlinearity=rectify,
                            W=weights_init, b=lasagne.init.Constant(0.1))
    # Linear output layer. W is left at DenseLayer's default initializer.
    network = ls.DenseLayer(network, output_size, nonlinearity=None, b=lasagne.init.Constant(.1))
    return network, input_layers, inputs
评论列表
文章目录