def resnet_arg_scope(
    weight_decay=0.0001,
    batch_norm_decay=0.997,
    batch_norm_epsilon=1e-5,
    batch_norm_scale=True,
):
    """Build the default argument scope for the ResNet layer ops.

    Two scopes are stacked:

    * conv2d, preact_conv2d and fully_connected get a variance-scaling
      weight initializer, an L2 weight regularizer and a ReLU activation;
    * conv2d and preact_conv2d additionally get ``layers.batch_norm`` as
      their normalizer, configured from the batch-norm arguments below.

    Args:
        weight_decay: Value handed to ``layers.l2_regularizer``.
        batch_norm_decay: Stored under the ``'decay'`` key of the
            normalizer parameters.
        batch_norm_epsilon: Stored under the ``'epsilon'`` key of the
            normalizer parameters.
        batch_norm_scale: Stored under the ``'scale'`` key of the
            normalizer parameters.

    Returns:
        The object yielded when the convolution scope is entered while the
        shared scope is active (presumably the merged scope, as with
        TF-Slim's ``arg_scope`` -- confirm against the imported helper).
    """
    conv_ops = (layers.conv2d, my_layers.preact_conv2d)
    bn_settings = dict(
        decay=batch_norm_decay,
        epsilon=batch_norm_epsilon,
        scale=batch_norm_scale,
    )
    regularizer = layers.l2_regularizer(weight_decay)

    # Both context managers are created up front, before either is entered.
    shared_scope = arg_scope(
        [*conv_ops, layers.fully_connected],
        weights_initializer=layers.variance_scaling_initializer(),
        weights_regularizer=regularizer,
        activation_fn=tf.nn.relu,
    )
    conv_scope = arg_scope(
        list(conv_ops),
        normalizer_fn=layers.batch_norm,
        normalizer_params=bn_settings,
    )

    # The inner scope is entered while the outer one is active; the value it
    # yields is what callers receive.
    with shared_scope:
        with conv_scope as combined_scope:
            return combined_scope
评论列表
文章目录