def make_subseparable_kernel(kernel_size, input_channels, filters, separability,
                             kernel_initializer, kernel_regularizer):
    """Make a kernel to do subseparable convolution with `tf.nn.conv2d`.

    Whatever the separability, the returned kernel has shape
    `kernel_size + (input_channels, filters)`. It can therefore be used
    directly as the filter argument of `tf.nn.conv2d`.

    Args:
      kernel_size: (height, width) tuple. Any 2-element sequence is accepted
        and converted to a tuple.
      input_channels: Number of input channels.
      filters: Number of output channels.
      separability: Integer denoting separability:
        * 1: ordinary non-separable kernel, stored in one variable "kernel".
        * 0 or -1: depthwise-separable kernel, built from a variable
          "depthwise_kernel" of shape kernel_size + (input_channels,) and a
          variable "pointwise_kernel" of shape (input_channels, filters).
        * |separability| >= 2: not implemented. Divisibility of `filters`
          and `input_channels` by |separability| is asserted first.
      kernel_initializer: Initializer to use for the kernel variable(s).
      kernel_regularizer: Regularizer to use for the kernel variable(s).

    Returns:
      A 4D tensor of shape kernel_size + (input_channels, filters).

    Raises:
      AssertionError: If |separability| >= 2 and `filters` or
        `input_channels` is not divisible by |separability| (only when
        assertions are enabled).
      NotImplementedError: If |separability| >= 2.
      ValueError: If `separability` matches none of the cases above
        (e.g. a non-integral value such as 1.5).
    """
    # The shape arithmetic below concatenates with `+`, which needs a tuple.
    kernel_size = tuple(kernel_size)

    if separability == 1:
        # Non-separable convolution.
        return tf.get_variable(
            "kernel",
            kernel_size + (input_channels, filters),
            initializer=kernel_initializer,
            regularizer=kernel_regularizer)

    if separability in (0, -1):
        # Separable convolution.
        # TODO(rshin): Check initialization is as expected, as these are not 4D.
        depthwise_kernel = tf.get_variable(
            "depthwise_kernel",
            kernel_size + (input_channels,),
            initializer=kernel_initializer,
            regularizer=kernel_regularizer)
        pointwise_kernel = tf.get_variable(
            "pointwise_kernel", (input_channels, filters),
            initializer=kernel_initializer,
            regularizer=kernel_regularizer)
        # The equivalent full kernel is
        #   kernel[h, w, c, f] = depthwise[h, w, c] * pointwise[c, f].
        # Broadcasting (h, w, c, 1) against (c, f) produces exactly this
        # (h, w, c, f) tensor. It avoids building a (c, c, h, w)
        # block-diagonal tensor and multiplying it through a matmul.
        return tf.expand_dims(depthwise_kernel, -1) * pointwise_kernel

    if abs(separability) >= 2:
        # Negative and positive values share the same group count.
        groups = abs(separability)
        assert filters % groups == 0, (filters, groups)
        assert input_channels % groups == 0, (input_channels, groups)
        raise NotImplementedError

    raise ValueError("Unsupported separability: %r" % (separability,))
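# 评论列表 ("comment list") and 文章目录 ("table of contents") are navigation
# labels scraped from the web page this code was copied from; they are not code.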