def linear(input: tf.Tensor,
           output_size: int,
           weight_initializer: Optional[Initializer] = None,
           bias_initializer: Optional[Initializer] = None,
           name: str = "linear") -> tf.Tensor:
    """
    Apply a linear (affine) transformation ``input @ weights + bias``.

    The weight matrix and bias vector are obtained with ``tf.get_variable``
    inside ``tf.variable_scope(name)``, so they are named ``<name>/weights``
    and ``<name>/bias`` relative to the enclosing variable scope.

    Parameters
    ----------
    input: tf.Tensor
        The tensor which should be linearly transformed. Its last dimension
        must be statically known, because it sets the number of rows of the
        ``[in_features, output_size]`` weight matrix that `input` is
        multiplied with (via ``tf.matmul``).
    output_size: int
        The desired output size of the linear transformation
    weight_initializer: tf.Initializer, optional
        A custom initializer for the weight matrix. If None, the default
        initializer of ``tf.get_variable`` is used.
    bias_initializer: tf.Initializer, optional
        A custom initializer for the bias vector. If None, the default
        initializer of ``tf.get_variable`` is used.
    name: str, optional
        Name of the variable scope holding the created variables
        (default "linear")

    Returns
    -------
    tf.Tensor
        The linearly transformed input tensor, whose last dimension is
        `output_size`

    Raises
    ------
    ValueError
        If the rank of `input` is unknown or zero, or if its last dimension
        is not statically known.
    """
    # NOTE: `input` shadows the builtin of the same name; it is kept because
    # callers may pass it by keyword.
    input_shape = input.get_shape()
    if input_shape.ndims is None or input_shape.ndims == 0:
        raise ValueError(
            "linear() requires an input tensor of known, non-zero rank; "
            "got shape {}".format(input_shape))

    in_features = input_shape.as_list()[-1]
    if in_features is None:
        # tf.get_variable needs a fully defined shape; fail here with a clear
        # message rather than with the generic error it would raise.
        raise ValueError(
            "linear() requires the last dimension of the input to be "
            "statically known; got shape {}".format(input_shape))

    with tf.variable_scope(name):
        weights = tf.get_variable(name="weights",
                                  shape=[in_features, output_size],
                                  dtype=tf.float32,
                                  initializer=weight_initializer)
        # dtype spelled out to match `weights`; float32 is also the
        # tf.get_variable default, so this does not change behaviour.
        bias = tf.get_variable(name="bias",
                               shape=[output_size],
                               dtype=tf.float32,
                               initializer=bias_initializer)

        # Kept inside the scope so the MatMul/add ops are created under
        # `name` as well.
        return tf.matmul(input, weights) + bias
评论列表
文章目录