def residual(inputs,
             depth,
             stride,
             activate_before_residual,
             residual_mask=None,
             scope=None):
    """Build a residual unit of two 3x3 convolutions with FLOPs accounting.

    The input is batch-normalized ('preact'), passed through two
    `flopsometer.conv2d` layers ('conv1', 'conv2') and added to a shortcut
    branch. When the number of channels changes, the shortcut is average-pooled
    with the unit's stride and zero-padded along the channel axis.

    Args:
      inputs: Input tensor of rank >= 4 whose last dimension is statically
        known (it is read with `slim.utils.last_dimension`).
      depth: Passed to both convolutions as their second positional argument
        (presumably the number of output channels -- confirm against
        `flopsometer.conv2d`).
      stride: Stride of the first convolution; also the kernel size and stride
        of the shortcut average pooling.
      activate_before_residual: If true, the shortcut branch starts from the
        batch-normalized input instead of the raw input.
      residual_mask: Optional mask tensor. It is passed as `output_mask` to the
        second convolution, a 3x3 max-pooled copy is passed as `output_mask`
        to the first convolution, and the convolution branch is multiplied by
        it before the addition. Requires `stride == 1`.
      scope: Optional variable scope name; defaults to 'residual'.

    Returns:
      A tuple `(outputs, flops)`: the sum of the shortcut and the convolution
      branch, and the sum of the FLOPs values reported by the two
      `flopsometer.conv2d` calls.

    Raises:
      ValueError: If `residual_mask` is given and `stride` is not 1.

    Note:
      The shortcut is only pooled when `depth` differs from the input depth,
      so a unit with `stride > 1` is expected to change the depth as well.
    """
    with tf.variable_scope(scope, 'residual', [inputs]):
        depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
        preact = slim.batch_norm(inputs, scope='preact')
        shortcut = preact if activate_before_residual else inputs

        if residual_mask is not None:
            # Max-pooling trick only works correctly when stride is 1.
            # We assume that stride=2 happens in the first layer where
            # residual_mask is None.
            # Raised explicitly (not asserted) so the check survives `-O`.
            if stride != 1:
                raise ValueError(
                    'residual_mask requires stride == 1, got stride={!r}'.format(
                        stride))
            # Grow the mask by one pixel in every direction; presumably so
            # that conv1 is evaluated wherever conv2's 3x3 window will read.
            dilated_residual_mask = slim.max_pool2d(
                residual_mask, [3, 3], stride=1, padding='SAME')
        else:
            dilated_residual_mask = None

        flops = 0

        conv_output, current_flops = flopsometer.conv2d(
            preact,
            depth,
            3,
            stride=stride,
            padding='SAME',
            output_mask=dilated_residual_mask,
            scope='conv1')
        flops += current_flops

        # No activation or normalization on the second convolution: its raw
        # output is what gets added to the shortcut.
        conv_output, current_flops = flopsometer.conv2d(
            conv_output,
            depth,
            3,
            stride=1,
            padding='SAME',
            activation_fn=None,
            normalizer_fn=None,
            output_mask=residual_mask,
            scope='conv2')
        flops += current_flops

        if depth_in != depth:
            shortcut = slim.avg_pool2d(
                shortcut, stride, stride, padding='VALID')
            # Split the missing channels between both ends of the channel
            # axis. For an odd difference the extra channel goes at the end,
            # so the padded shortcut always has exactly `depth` channels.
            pad_total = depth - depth_in
            pad_before = pad_total // 2
            pad_after = pad_total - pad_before
            shortcut = tf.pad(
                shortcut, [[0, 0], [0, 0], [0, 0], [pad_before, pad_after]])

        if residual_mask is not None:
            conv_output *= residual_mask

        outputs = shortcut + conv_output
        return outputs, flops
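# Web-page navigation residue (not part of the code): "Comment list".
# Web-page navigation residue (not part of the code): "Table of contents".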