ops.py source file

python

Project: auDeep  Author: auDeep
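# Imports this snippet relies on. They are not shown in the original listing and
# are assumed here; the code targets the TensorFlow 1.x graph API.
from typing import Optional

import tensorflow as tf
from tensorflow.python.ops.init_ops import Initializer


def conv2d(input: tf.Tensor,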
           output_dim: int,
           kernel_width: int = 5,
           kernel_height: int = 5,
           horizontal_stride: int = 2,
           vertical_stride: int = 2,
           weight_initializer: Optional[Initializer] = None,
           bias_initializer: Optional[Initializer] = None,
           name: str = "conv2d"):
    """
    Apply a 2D-convolution to a tensor.

    Parameters
    ----------
    input: tf.Tensor
        The tensor to which the convolution should be applied. Must be of shape [batch_size, height, width, channels]
    output_dim: int
        The number of convolutional filters
    kernel_width: int, optional
        The width of the convolutional filters (default 5)
    kernel_height: int, optional
        The height of the convolutional filters (default 5)
    horizontal_stride: int, optional
        The horizontal stride of the convolutional filters (default 2)
    vertical_stride: int, optional
        The vertical stride of the convolutional filters (default 2)
    weight_initializer: tf.Initializer, optional
        A custom initializer for the weight matrices of the filters
    bias_initializer: tf.Initializer, optional
        A custom initializer for the bias vectors of the filters
    name: str, optional
        A name for the operation (default "conv2d")

    Returns
    -------
    tf.Tensor
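        The result of applying a 2D-convolution to the input tensor. Because 'SAME' padding is used, the result has
        shape [batch_size, ceil(height / vertical_stride), ceil(width / horizontal_stride), output_dim].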
    """
    shape = input.get_shape().as_list()

    with tf.variable_scope(name):
        weights = tf.get_variable(name="weights",
                                  shape=[kernel_height, kernel_width, shape[-1], output_dim],
                                  initializer=weight_initializer)

        bias = tf.get_variable(name="bias",
                               shape=[output_dim],
                               initializer=bias_initializer)

        conv = tf.nn.conv2d(input,
                            filter=weights,
                            strides=[1, vertical_stride, horizontal_stride, 1],
                            padding='SAME')

        conv = tf.nn.bias_add(conv, bias)

        return conv
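
The function pads with 'SAME', so the spatial size of the result depends only on the input size and the strides, while the kernel size only affects how much zero padding is added. As a rough sketch of that arithmetic, the following pure-Python helpers compute the expected output shape without needing TensorFlow. The names `same_padding` and `conv2d_output_shape` are hypothetical and not part of auDeep.

```python
import math
from typing import List, Sequence, Tuple


def same_padding(size: int, kernel: int, stride: int) -> Tuple[int, int, int]:
    """Return (output_size, pad_before, pad_after) along one spatial axis
    for 'SAME' padding."""
    out = math.ceil(size / stride)
    # Total zero padding needed so that `out` windows of width `kernel` fit.
    pad_total = max((out - 1) * stride + kernel - size, 0)
    pad_before = pad_total // 2
    # Any odd leftover element of padding goes at the end.
    return out, pad_before, pad_total - pad_before


def conv2d_output_shape(input_shape: Sequence[int],
                        output_dim: int,
                        kernel_width: int = 5,
                        kernel_height: int = 5,
                        horizontal_stride: int = 2,
                        vertical_stride: int = 2) -> List[int]:
    """Hypothetical helper: shape produced by conv2d above for an input of
    shape [batch_size, height, width, channels]."""
    batch_size, height, width, _ = input_shape
    out_height, _, _ = same_padding(height, kernel_height, vertical_stride)
    out_width, _, _ = same_padding(width, kernel_width, horizontal_stride)
    return [batch_size, out_height, out_width, output_dim]


# With the default stride of 2, each spatial dimension is halved (rounded up).
print(conv2d_output_shape([32, 128, 100, 1], output_dim=64))  # [32, 64, 50, 64]
```

With the default 5x5 kernel and stride 2, an axis of length 128 receives 3 elements of zero padding in total (1 before, 2 after).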