def feed_forward_net(input, output, hidden_layers=(64, 64), activations='relu',
                     dropout_rate=0., l2=0., constrain_norm=False):
    """Build a Keras feed-forward network and return its output tensor.

    Stacks one ``Dense`` layer per entry in ``hidden_layers`` on top of
    ``input`` (each optionally followed by a ``Dropout`` layer), then applies
    ``output`` to the last hidden state.

    Args:
        input: Keras ``Input`` object appropriate for the data,
            e.g. ``input=Input(shape=(20,))``.
        output: Callable representing the final layer, mapping the last
            hidden layer to the network output, e.g.
            ``Dense(10, activation='softmax')`` for 10-class classification
            or ``Dense(1, activation='linear')`` for regression.
        hidden_layers: Sequence of hidden-layer widths (units per ``Dense``).
        activations: Either a single activation name (str) used for every
            hidden layer, or a sequence with exactly one activation per
            hidden layer.
        dropout_rate: If > 0, a ``Dropout`` layer with this rate is added
            after each hidden layer.
        l2: If > 0, an L2 kernel regularizer with this factor is attached to
            each hidden layer.
        constrain_norm: If True, each hidden layer's kernel gets a
            ``maxnorm(2)`` constraint.

    Returns:
        ``output(state)``, where ``state`` is the last hidden layer's output
        (or ``input`` itself when ``hidden_layers`` is empty).

    Raises:
        ValueError: If ``activations`` is a sequence whose length differs
            from the number of hidden layers.
    """
    # Materialise so generators are accepted and len() is available.
    hidden_layers = list(hidden_layers)
    if isinstance(activations, str):
        activations = [activations] * len(hidden_layers)
    else:
        activations = list(activations)
        if len(activations) != len(hidden_layers):
            # zip() below would otherwise silently drop hidden layers.
            raise ValueError(
                "got {} activations for {} hidden layers".format(
                    len(activations), len(hidden_layers)))

    state = input
    for units, activation in zip(hidden_layers, activations):
        # Build a fresh regularizer/constraint object for each layer.
        w_reg = keras.regularizers.l2(l2) if l2 > 0. else None
        const = maxnorm(2) if constrain_norm else None
        state = Dense(units, activation=activation,
                      kernel_regularizer=w_reg,
                      kernel_constraint=const)(state)
        if dropout_rate > 0.:
            state = Dropout(dropout_rate)(state)
    return output(state)