torch_backend.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目：ktorch　作者：farizrahman4u　项目源码　文件源码
def bias_add(x, bias, data_format=None):
    """Add ``bias`` to ``x`` as a single backend op.

    Inside the op, the bias is reshaped so its channel dimension lines up
    with the channel axis of ``x`` (chosen by ``data_format``). It is then
    expanded to the full shape of ``x`` and added element-wise.

    Args:
        x: input tensor.
        bias: bias tensor; its rank must be 1 or ``ndim(x) - 1``.
        data_format: ``'channels_first'``, ``'channels_last'`` or ``None``.
            ``None`` falls back to ``keras.backend.image_data_format()``.

    Returns:
        Whatever ``get_op(...)([x, bias])`` produces. The declared output
        shape is the shape of ``x``.

    Raises:
        ValueError: (from within the op) if ``data_format`` is unknown, or if
            the bias rank is neither 1 nor ``ndim(x) - 1``.
    """
    def _bias_add(pair, data_format):
        tensor, b = pair
        from keras.backend import image_data_format, ndim, reshape
        if data_format is None:
            data_format = image_data_format()
        if data_format not in {'channels_first', 'channels_last'}:
            raise ValueError('Unknown data_format ' + str(data_format))

        bias_rank = ndim(b)
        if bias_rank != 1 and bias_rank != ndim(tensor) - 1:
            raise ValueError('Unexpected bias dimensions %d, '
                             'expect to be 1 or %d dimensions'
                             % (bias_rank, ndim(tensor) - 1))

        b_shape = tuple(b.size())
        x_rank = len(tensor.size())
        # Only rank-3/4/5 inputs get a reshaped bias; any other rank is
        # passed straight to expand_as below.
        if 3 <= x_rank <= 5:
            spatial = x_rank - 2  # number of non-batch, non-channel axes
            channels_first = data_format == 'channels_first'
            if len(b_shape) == 1:
                channels = b_shape[0]
                if channels_first:
                    target = (1, channels) + (1,) * spatial
                else:
                    target = (1,) * (x_rank - 1) + (channels,)
            elif channels_first:
                # The channel count is read from index `spatial` of the bias
                # shape and placed at axis 1. This is a reshape, not a
                # permute, matching the per-rank indices of the reference
                # implementation.
                target = (1, b_shape[spatial]) + b_shape[:spatial]
            else:
                target = (1,) + b_shape
            b = reshape(b, target)
        return tensor.add(b.expand_as(tensor))

    def _compute_output_shape(pair):
        # The op's output has the same shape as the first input (x).
        return _get_shape(pair[0])

    return get_op(_bias_add,
                  output_shape=_compute_output_shape,
                  arguments=[data_format])([x, bias])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号