layers.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def merge(tensors_list, mode, axis=1, name='merge', outputs_collections=None, **kwargs):
    """
    Merge a list of tensors into a single tensor.

    Args:
        tensors_list: A list of `Tensors` to merge; must contain at least two.
        mode: str, one of
            ['concat', 'elemwise_sum', 'elemwise_mul', 'sum',
                'mean', 'prod', 'max', 'min', 'and', 'or'].
            - 'concat' joins the tensors along `axis`.
            - 'elemwise_sum' / 'elemwise_mul' add / multiply the tensors
              pairwise, left to right; `axis` is not used.
            - The remaining modes concatenate the tensors along `axis` and
              then reduce over that same axis, so `axis` is dropped from the
              result ('and' / 'or' map to `tf.reduce_all` / `tf.reduce_any`).
        axis: int, the axis handed to `tf.concat` and to the reduction op.
        name: a optional scope/name of the layer
        outputs_collections: The collections to which the outputs are added.
        **kwargs: accepted for call-site compatibility; not used here.

    Returns:
        The value returned by `_collect_named_outputs` for the merged tensor
        (expected to be the merged `Tensor`; the helper is defined elsewhere).

    Raises:
        AssertionError: If fewer than two tensors are given.
        ValueError: If `mode` is not one of the supported modes.
    """
    assert len(tensors_list) > 1, "Merge required 2 or more tensors."

    # Modes implemented as "concat along axis, then reduce along axis",
    # mapped to the name of the TensorFlow reduction that performs them.
    reducers = {
        'sum': 'reduce_sum',
        'mean': 'reduce_mean',
        'prod': 'reduce_prod',
        'max': 'reduce_max',
        'min': 'reduce_min',
        'and': 'reduce_all',
        'or': 'reduce_any',
    }
    known_modes = ('concat', 'elemwise_sum', 'elemwise_mul') + tuple(reducers)
    # Reject a bad mode before opening the name scope so a failed call does
    # not leave an empty scope behind.
    if mode not in known_modes:
        raise ValueError("Unknown merge mode", str(mode))

    with tf.name_scope(name):
        tensors = list(tensors_list)
        if mode == 'concat':
            output = tf.concat(tensors, axis=axis)
        elif mode == 'elemwise_sum':
            output = tensors[0]
            for tensor in tensors[1:]:
                output = tf.add(output, tensor)
        elif mode == 'elemwise_mul':
            output = tensors[0]
            for tensor in tensors[1:]:
                output = tf.multiply(output, tensor)
        else:
            reduce_fn = getattr(tf, reducers[mode])
            output = reduce_fn(tf.concat(tensors, axis=axis), axis=axis)
        return _collect_named_outputs(outputs_collections, name, output)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号