def define_network(inputs):
    """Assemble the 3D convolutional classifier and return its output layer.

    Layer stack, in construction order (S = params.INPUT_SIZE):

        input (None, params.CHANNELS, S, S, S)
        -> conv 64 filters, 5x5x5, leaky ReLU
        -> max-pool 2x2x2            [-> batch norm on the pooled output]
        -> conv 64 filters, 5x5x5, leaky ReLU
        -> conv 96 filters, 5x5x5, leaky ReLU   [wrapped by batch norm]
        -> dense 420 units, leaky ReLU
        -> dense params.N_CLASSES units, softmax

    The bracketed batch-norm stages are added only when
    ``params.BATCH_NORMALIZATION`` is truthy.

    Args:
        inputs: Symbolic variable fed to the ``InputLayer`` through
            ``input_var``. It is presumably a Theano 5D tensor matching the
            declared shape; confirm against the caller.

    Returns:
        The final ``lasagne.layers.DenseLayer``, whose output is a softmax
        over ``params.N_CLASSES`` units.
    """
    layers = lasagne.layers
    leaky = lasagne.nonlinearities.leaky_rectify
    edge = params.INPUT_SIZE

    def _conv(incoming, n_filters):
        # 5x5x5 3D convolution with leaky ReLU. No pad/stride is passed, so the
        # layer's own defaults apply.
        # NOTE(review): the He init uses the plain-ReLU gain although the
        # activation is leaky_rectify. This is kept as-is to preserve behavior.
        return Conv3DDNNLayer(
            incoming,
            num_filters=n_filters,
            filter_size=(5, 5, 5),
            nonlinearity=leaky,
            W=HeNormal(gain='relu'),
        )

    def _maybe_bn(layer):
        # Per the lasagne docs, batch_norm() on a layer with a nonlinearity/bias
        # (the conv below) moves the nonlinearity after the BN stage. On the
        # pooling layer it simply appends a BN layer.
        if not params.BATCH_NORMALIZATION:
            return layer
        return layers.batch_norm(layer)

    net = layers.InputLayer(
        shape=(None, params.CHANNELS, edge, edge, edge),
        input_var=inputs,
    )

    # First conv block: conv -> pool -> optional BN.
    net = _conv(net, 64)
    net = _maybe_bn(MaxPool3DDNNLayer(net, pool_size=(2, 2, 2)))

    # Second conv block: two stacked convs, the last one optionally BN-wrapped.
    net = _conv(net, 64)
    net = _maybe_bn(_conv(net, 96))

    # Classifier head.
    net = layers.DenseLayer(
        net,
        num_units=420,
        nonlinearity=leaky,
        W=HeNormal(gain='relu'),
    )
    return layers.DenseLayer(
        net,
        num_units=params.N_CLASSES,
        nonlinearity=lasagne.nonlinearities.softmax,
    )
评论列表
文章目录