bluenet.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def make_subseparable_kernel(kernel_size, input_channels, filters, separability,
                             kernel_initializer, kernel_regularizer):
  """Make a kernel to do subseparable convolution with `tf.nn.conv2d`.

  Every supported path returns a kernel of shape
  `kernel_size + (input_channels, filters)`. `separability` only controls how
  that kernel is parameterized:

  * `1`: a single dense kernel variable (ordinary convolution).
  * `0` or `-1`: a depthwise kernel of shape `kernel_size + (input_channels,)`
    and a pointwise kernel of shape `(input_channels, filters)`. They are
    combined into one dense kernel whose entry `[h, w, c_in, c_out]` equals
    `depthwise[h, w, c_in] * pointwise[c_in, c_out]`.
  * `|separability| >= 2`: grouped variants. Only the divisibility checks are
    implemented; afterwards `NotImplementedError` is raised.

  Args:
    kernel_size: (height, width) tuple. A list is also accepted. The
      transposes below assume exactly two spatial dimensions.
    input_channels: Number of input channels.
    filters: Number of output channels.
    separability: Integer denoting separability (see above).
    kernel_initializer: Initializer to use for the kernel variable(s).
    kernel_regularizer: Regularizer to use for the kernel variable(s).

  Returns:
    A 4D tensor of shape `kernel_size + (input_channels, filters)`.

  Raises:
    AssertionError: If `|separability| >= 2` and `filters` or
      `input_channels` is not divisible by `|separability|`.
    NotImplementedError: If `|separability| >= 2`.
    ValueError: If `separability` matches none of the supported values
      (e.g. a non-integer such as 1.5).
  """
  # Normalize so the tuple concatenations below also work when a list is given.
  kernel_size = tuple(kernel_size)

  if separability == 1:
    # Non-separable convolution
    return tf.get_variable(
        "kernel",
        kernel_size + (input_channels, filters),
        initializer=kernel_initializer,
        regularizer=kernel_regularizer)

  if separability in (0, -1):
    # Separable convolution
    # TODO(rshin): Check initialization is as expected, as these are not 4D.
    depthwise_kernel = tf.get_variable(
        "depthwise_kernel",
        kernel_size + (input_channels,),
        initializer=kernel_initializer,
        regularizer=kernel_regularizer)

    pointwise_kernel = tf.get_variable(
        "pointwise_kernel", (input_channels, filters),
        initializer=kernel_initializer,
        regularizer=kernel_regularizer)

    # Scatter each channel's (h, w) depthwise filter onto the diagonal
    # [c, c] of an (input_channels, input_channels, h, w) tensor. Then move
    # the spatial axes first, giving (h, w, input_channels, input_channels).
    expanded_depthwise_kernel = tf.transpose(
        tf.scatter_nd(
            indices=tf.tile(
                tf.expand_dims(tf.range(0, input_channels), axis=1), [1, 2]),
            updates=tf.transpose(depthwise_kernel, (2, 0, 1)),
            shape=(input_channels, input_channels) + kernel_size), (2, 3, 0, 1))

    # Fold the spatial axes into rows so one matmul applies the pointwise
    # kernel. Then restore the conv2d filter layout.
    return tf.reshape(
        tf.matmul(
            tf.reshape(expanded_depthwise_kernel, (-1, input_channels)),
            pointwise_kernel), kernel_size + (input_channels, filters))

  if abs(separability) >= 2:
    # Grouped (sub-separable) variants: the positive and negative forms share
    # the same divisibility requirement on |separability|.
    groups = abs(separability)
    assert filters % groups == 0, (filters, groups)
    assert input_channels % groups == 0, (input_channels, groups)

    raise NotImplementedError

  # Reached only for values such as 1.5, -1.5 or NaN, which previously made the
  # function silently return None.
  raise ValueError(
      "separability must be an integer, got %r" % (separability,))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号