ops.py source file

python
Project: auDeep  Author: auDeep
from typing import Optional

import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)
# Assumed import path for the Initializer type used in the annotations below;
# the snippet itself does not show its imports.
from tensorflow.python.ops.init_ops import Initializer


def linear(input: tf.Tensor,
           output_size: int,
           weight_initializer: Optional[Initializer] = None,
           bias_initializer: Optional[Initializer] = None,
           name: str = "linear") -> tf.Tensor:
    """
    Apply a linear transformation to a tensor.

    Parameters
    ----------
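    input: tf.Tensor
        The tensor to transform, of shape [batch_size, num_features]. It must be rank 2,
        since tf.matmul multiplies it with a rank-2 weight matrix
    output_size: int
        The desired output size of the linear transformation
    weight_initializer: Initializer, optional
        A custom initializer for the weight matrix of the linear transformation. If None,
        the default initializer of tf.get_variable is used
    bias_initializer: Initializer, optional
        A custom initializer for the bias vector of the linear transformation. If None,
        the default initializer of tf.get_variable is used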
    name: str, optional
        A name for the operation (default "linear")

    Returns
    -------
    tf.Tensor
        The linearly transformed input tensor
    """
    shape = input.get_shape().as_list()

    with tf.variable_scope(name):
        weights = tf.get_variable(name="weights",
                                  shape=[shape[-1], output_size],
                                  dtype=tf.float32,
                                  initializer=weight_initializer)

        # Explicit dtype for consistency with the weight matrix above.
        bias = tf.get_variable(name="bias",
                               shape=[output_size],
                               dtype=tf.float32,
                               initializer=bias_initializer)

        return tf.matmul(input, weights) + bias
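
To make the shape contract of the function above concrete, here is a minimal NumPy sketch of the same arithmetic (`input @ weights + bias`) for a rank-2 input. It is not part of auDeep: the name `linear_reference` and the example shapes are assumptions chosen for illustration, and it takes explicit weight and bias arrays instead of creating TensorFlow variables.

```python
import numpy as np


def linear_reference(x: np.ndarray, weights: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for tf.matmul(input, weights) + bias."""
    # The TensorFlow version multiplies against a rank-2 weight matrix,
    # so the input is expected to have shape [batch_size, num_features].
    if x.ndim != 2:
        raise ValueError("expected a rank-2 input of shape [batch_size, num_features]")
    # The weight matrix has shape [num_features, output_size].
    if weights.ndim != 2 or weights.shape[0] != x.shape[1]:
        raise ValueError("weights must have shape [num_features, output_size]")
    # The bias vector has shape [output_size] and is broadcast over the batch.
    if bias.shape != (weights.shape[1],):
        raise ValueError("bias must have shape [output_size]")
    return x @ weights + bias


# Example shapes (assumed): batch of 4 vectors with 3 features, output_size 2.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3)).astype(np.float32)
weights = rng.standard_normal((3, 2)).astype(np.float32)
bias = np.zeros(2, dtype=np.float32)

y = linear_reference(x, weights, bias)
print(y.shape)  # (4, 2)
```

In the TensorFlow 1.x version, the variables are created with `tf.get_variable` inside `tf.variable_scope(name)`, so applying `linear` twice in the same graph generally requires a distinct `name` per call (or a scope with reuse enabled); otherwise the second call fails because the variables already exist.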