def build_classification_network(r1):
    """Attach a softmax classification head on top of ``r1``.

    Args:
        r1: Either an existing ``lasagne.layers.Layer`` or a raw symbolic
            variable. A raw variable is wrapped in an ``InputLayer`` declared
            with shape ``(None, glimpse_output_size, recurrent_output_size)``.
            Presumably this is the recurrent state of the attention model;
            TODO confirm against callers.

    Returns:
        A ``lasagne.layers.DenseLayer`` with ``classification_units`` outputs
        and softmax nonlinearity. Its weights and bias are the module-level
        ``class_weights`` / ``class_bias``, which are presumably shared so
        that repeated calls reuse one set of parameters. Verify this where
        they are defined.
    """
    if isinstance(r1, lasagne.layers.Layer):
        incoming = r1
    else:
        # DenseLayer needs a Layer as its input, so a bare variable is wrapped first.
        incoming = lasagne.layers.InputLayer(
            shape=(None, glimpse_output_size, recurrent_output_size),
            input_var=r1,
        )
    # NOTE(review): the declared input is 3-D. Lasagne's DenseLayer flattens
    # trailing axes by default (num_leading_axes=1), so confirm this is the
    # intended reshaping.
    return lasagne.layers.DenseLayer(
        incoming,
        num_units=classification_units,
        nonlinearity=nl.softmax,
        W=class_weights,
        b=class_bias,
    )
# Input is a downsampled batch of images.
# Output is the initial r2, of length glimpse_output_size.
评论列表
文章目录