def bias_add(x, bias, data_format=None):
    """Add ``bias`` to ``x``, aligning the bias with the channel axis.

    The actual addition is wrapped by ``get_op`` and applied to the pair
    ``[x, bias]``; ``data_format`` is forwarded to the wrapped function.

    # Arguments
        x: input tensor. For rank >= 3 the bias is reshaped so that its
            channel dimension lines up with axis 1 (``channels_first``) or
            with the last axis (``channels_last``). For lower ranks it is
            broadcast against ``x`` as-is via ``expand_as``.
        bias: bias tensor of rank 1 (one value per channel) or of rank
            ``ndim(x) - 1``.
        data_format: ``'channels_first'``, ``'channels_last'`` or ``None``.
            ``None`` is resolved with ``keras.backend.image_data_format()``
            inside the wrapped function.

    # Returns
        Whatever ``get_op(...)([x, bias])`` returns. The declared output
        shape is ``_get_shape(x)``.

    # Raises
        ValueError: if ``data_format`` is not one of the two accepted values,
            or if ``bias`` has neither rank 1 nor rank ``ndim(x) - 1``.
            These checks run inside the wrapped function, not at call time.
    """
    def _bias_add(X, data_format):
        x, bias = X
        from keras.backend import image_data_format, ndim, reshape
        if data_format is None:
            data_format = image_data_format()
        if data_format not in {'channels_first', 'channels_last'}:
            raise ValueError('Unknown data_format ' + str(data_format))
        if ndim(bias) != 1 and ndim(bias) != ndim(x) - 1:
            raise ValueError('Unexpected bias dimensions %d, '
                             'expect to be 1 or %d dimensions'
                             % (ndim(bias), ndim(x) - 1))
        bias_shape = tuple(bias.size())
        ndim_x = len(x.size())
        if ndim_x >= 3:
            # Place the channel dimension explicitly for every rank >= 3.
            # Relying on trailing-axis broadcasting would put a per-channel
            # bias on the last axis, which is wrong for channels_first.
            n_spatial = ndim_x - 2
            if len(bias_shape) == 1:
                if data_format == 'channels_first':
                    # (1, C, 1, ..., 1)
                    target_shape = (1,) + bias_shape + (1,) * n_spatial
                else:
                    # (1, 1, ..., 1, C)
                    target_shape = (1,) * (n_spatial + 1) + bias_shape
            elif data_format == 'channels_first':
                # NOTE(review): this is a plain reshape, not an axis
                # permutation. It reinterprets the bias buffer with its last
                # dimension moved to the front, exactly as the original
                # per-rank code did. If the bias is stored channels-last,
                # the values will not line up with channels. TODO confirm the
                # intended layout of a full-rank bias.
                target_shape = (1, bias_shape[-1]) + bias_shape[:-1]
            else:
                # Full-rank channels_last bias: just add a batch axis.
                target_shape = (1,) + bias_shape
            bias = reshape(bias, target_shape)
        return x.add(bias.expand_as(x))

    def _compute_output_shape(X):
        # The op's output shape is declared to be that of x (X[0]).
        return _get_shape(X[0])

    return get_op(_bias_add, output_shape=_compute_output_shape,
                  arguments=[data_format])([x, bias])
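# 评论列表 (comment list) — navigation label left over from the source web page
# 文章目录 (table of contents) — navigation label left over from the source web page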