def mlp_ptscorer(inputs, Ddim, N, l2reg, pfx='out', Dinit='glorot_uniform', sum_mode='sum', extra_inp=None):
    """Score a pair of inputs with an MLP over element-wise pair features.

    The pair is first turned into element-wise features (selected by
    ``sum_mode``), which are then passed through zero or more ``tanh`` hidden
    ``Dense`` layers and a final single-unit ``Dense`` layer.

    Args:
        inputs: The pair to compare.  It is indexed as ``x[0]`` / ``x[1]`` in
            ``'absdiff'`` mode and handed to ``add`` / ``multiply`` in
            ``'sum'`` mode.  (Presumably two Keras tensors of equal shape --
            confirm against the callers.)
        Ddim: Hidden layer size factor(s): ``0`` for no hidden layer, a
            scalar for a single hidden layer, or a list/tuple for several
            hidden layers.  Each hidden layer has ``int(N * D)`` units.
        N: Base width that every ``Ddim`` factor is multiplied by.
        l2reg: Coefficient passed to ``l2()`` to build the kernel regularizer
            of every ``Dense`` layer.
        pfx: Unused by this implementation; kept so existing callers that
            pass it keep working.
        Dinit: Kernel initializer of the hidden layers (the final layer does
            not receive it).
        sum_mode: ``'sum'`` feeds the element-wise sum and product of the
            pair to the MLP; ``'absdiff'`` feeds ``|inputs[0] - inputs[1]|``.
        extra_inp: Optional sequence of extra features concatenated after the
            sum and product.  Only used in ``'sum'`` mode.

    Returns:
        The output of the final ``Dense(1)`` layer.

    Raises:
        ValueError: If ``sum_mode`` is neither ``'sum'`` nor ``'absdiff'``.
    """
    # Validate before any layer is created: an unknown mode used to fall
    # through both branches and fail much later with an UnboundLocalError.
    if sum_mode not in ('sum', 'absdiff'):
        raise ValueError("unknown sum_mode %r (expected 'sum' or 'absdiff')" % (sum_mode,))

    if sum_mode == 'absdiff':
        absdiff = Lambda(function=lambda x: K.abs(x[0] - x[1]),
                         output_shape=lambda shape: shape[0])
        # NOTE(review): extra_inp is not used in this mode; this matches the
        # previous behaviour -- confirm whether that is intended.
        features = absdiff(inputs)
    else:
        linear = Activation('linear')
        extras = list(extra_inp) if extra_inp is not None else []
        # Concatenate unconditionally so that the layers below always receive
        # a single tensor, even when there is no hidden layer (e.g. Ddim=[]).
        features = concatenate([linear(add(inputs)), linear(multiply(inputs))] + extras)

    # Normalize Ddim to a list of per-layer factors.
    if isinstance(Ddim, (list, tuple)):
        hidden_factors = list(Ddim)
    elif Ddim == 0:
        hidden_factors = []
    else:
        hidden_factors = [Ddim]

    for factor in hidden_factors:
        features = Dense(int(N * factor), activation='tanh',
                         kernel_initializer=Dinit,
                         kernel_regularizer=l2(l2reg))(features)

    return Dense(1, kernel_regularizer=l2(l2reg))(features)
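# (page-navigation residue from the source web page: "Comment list")
# (page-navigation residue from the source web page: "Table of contents")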