nn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:nmt 作者: Playinf 项目源码 文件源码
def linear(inputs, size, bias, concat=False, dtype=None, scope=None):
    """Project one or more inputs with weight matrices and sum the results.

    Args:
        inputs: A single input or a list/tuple of inputs. Each one is passed
            to ``theano.dot`` (or first joined along the last axis when
            ``concat`` is true).
        size: A list/tuple ``(input_size, output_size)``. ``input_size`` may
            itself be a list/tuple holding one entry per input.
        bias: If truthy, a variable named ``"bias"`` with shape
            ``[output_size]`` is added to the result.
        concat: If truthy, the inputs are concatenated and multiplied by one
            variable named ``"matrix"``; otherwise every input gets its own
            variable named ``"matrix_<index>"``.
        dtype: Forwarded to ``get_variable``.
        scope: Forwarded to ``variable_scope``.

    Returns:
        The only term when exactly one was produced, otherwise all terms
        folded together with ``theano.tensor.add``.

    Raises:
        ValueError: If ``size`` is not a list or tuple.
        RuntimeError: If the number of inputs differs from the number of
            input sizes.
    """
    if not isinstance(size, (list, tuple)):
        raise ValueError("size argument must be (input_size, output_size)")

    input_size, output_size = size

    # Normalize scalars so a single input is handled like a one-item list.
    in_dims = input_size if isinstance(input_size, (list, tuple)) \
        else [input_size]
    tensors = inputs if isinstance(inputs, (list, tuple)) else [inputs]

    if len(tensors) != len(in_dims):
        raise RuntimeError("unmatched elements found: inputs and input_size")

    terms = []

    with variable_scope(scope):
        if not concat:
            # One weight matrix per input, named by the input's position.
            for index, (tensor, dim) in enumerate(zip(tensors, in_dims)):
                weight = get_variable("matrix_%d" % index,
                                      [dim, output_size], dtype=dtype)
                terms.append(theano.dot(tensor, weight))
        else:
            # A single weight matrix applied to the joined inputs.
            total_dim = sum(in_dims)
            joined = theano.tensor.concatenate(tensors, -1)
            weight = get_variable("matrix", [total_dim, output_size],
                                  dtype=dtype)
            terms.append(theano.dot(joined, weight))

        if bias:
            terms.append(get_variable("bias", [output_size], dtype=dtype))

    if len(terms) == 1:
        return terms[0]

    return reduce(theano.tensor.add, terms)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号