nn.py source code

python
Project: XMUNMT  Author: XMUNLP
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)


def linear(inputs, output_size, bias, concat=False, dtype=None, scope=None):
    """
    Linear layer

    Args:
        inputs: A Tensor or a list of Tensors with shape [..., input_size];
                all inputs must share the same leading dimensions
        output_size: An integer specifying the output size
        bias: A boolean indicating whether to add a bias term
        concat: A boolean indicating whether to concatenate all inputs
                and use a single weight matrix
        dtype: An instance of tf.DType, the default value is ``tf.float32``
        scope: The scope of this layer, the default value is ``linear``

    Returns:
        A Tensor with shape [..., output_size]

    Raises:
        RuntimeError: if the input sizes are not compatible with
                      each other
    """

    with tf.variable_scope(scope, default_name="linear", values=[inputs]):
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        input_size = [item.get_shape()[-1].value for item in inputs]

        if len(inputs) != len(input_size):
            raise RuntimeError("inputs and input_size unmatched!")

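        # The output keeps the leading dimensions of the first input and
        # replaces the last dimension with output_size.
        output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]],
                                 axis=0)
        # Flatten every input to 2D: [-1, input_size]
        inputs = [tf.reshape(inp, [-1, inp.shape[-1].value]) for inp in inputs]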

        results = []

        if concat:
            # Concatenate all inputs along the feature axis and apply a
            # single weight matrix.
            input_size = sum(input_size)
            inputs = tf.concat(inputs, 1)

            shape = [input_size, output_size]
            matrix = tf.get_variable("matrix", shape, dtype=dtype)
            results.append(tf.matmul(inputs, matrix))
        else:
            # Apply one weight matrix per input; the projections are
            # summed below.
            for i in range(len(input_size)):
                shape = [input_size[i], output_size]
                name = "matrix_%d" % i
                matrix = tf.get_variable(name, shape, dtype=dtype)
                results.append(tf.matmul(inputs[i], matrix))

        output = tf.add_n(results)

        if bias:
            shape = [output_size]
            # Use a separate name so the boolean argument is not shadowed.
            bias_term = tf.get_variable("bias", shape, dtype=dtype)
            output = tf.nn.bias_add(output, bias_term)

        output = tf.reshape(output, output_shape)

        return output
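
The two branches of `linear` relate to each other in a simple way: applying one matrix to the concatenated inputs is equivalent to summing the per-input projections when the single matrix is the per-input matrices stacked row-wise. The following is a minimal NumPy sketch of that relationship and of the flatten/reshape handling of leading dimensions. It is not part of XMUNMT; `linear_reference` and all variable names below are hypothetical and chosen only for illustration.

```python
import numpy as np


def linear_reference(inputs, matrices, bias=None):
    """NumPy sketch of the per-input (concat=False) branch.

    Hypothetical helper for illustration only; not part of XMUNMT.
    """
    # The output keeps the leading dimensions of the first input.
    output_size = matrices[0].shape[1]
    output_shape = inputs[0].shape[:-1] + (output_size,)
    # Flatten every input to 2D: [-1, input_size].
    flat = [x.reshape(-1, x.shape[-1]) for x in inputs]
    # One weight matrix per input; the projections are summed.
    output = sum(x @ w for x, w in zip(flat, matrices))
    if bias is not None:
        output = output + bias
    # Restore the leading dimensions.
    return output.reshape(output_shape)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))  # e.g. [batch, time, 4]
b = rng.standard_normal((2, 3, 5))  # e.g. [batch, time, 5]
w_a = rng.standard_normal((4, 6))
w_b = rng.standard_normal((5, 6))

# Per-input form: a @ w_a + b @ w_b
separate = linear_reference([a, b], [w_a, w_b])

# Concatenated form: [a, b] @ [w_a; w_b]
joined = np.concatenate([a, b], axis=-1)
stacked = np.concatenate([w_a, w_b], axis=0)
concatenated = linear_reference([joined], [stacked])

print(separate.shape)                       # (2, 3, 6)
print(np.allclose(separate, concatenated))  # True
```

In the TensorFlow function the two modes create differently named variables (`matrix` versus `matrix_0`, `matrix_1`, ...), so a checkpoint trained with one setting would presumably not load directly under the other, even though both forms compute the same family of functions.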