def _fill_and_one_pad_stride(stride, n, data_format=DATA_FORMAT_NHWC):
    """Expands the provided stride to size n and pads it with 1s.

    Args:
      stride: Either an integer, an iterable with at most `n` entries, or an
        iterable with exactly `n + 2` entries. In the last case the stride is
        treated as already padded and is returned unchanged.
      n: Size the stride is expanded to (via `_fill_shape`) before padding.
      data_format: `DATA_FORMAT_NHWC` (default) or `DATA_FORMAT_NCHW`. Selects
        where the padding 1s go: one on each side for NHWC, two in front for
        NCHW.

    Returns:
      For an integer or short iterable: `(1,) + filled + (1,)` (NHWC) or
      `(1, 1) + filled` (NCHW), where `filled = _fill_shape(stride, n)`
      (presumably a tuple of length `n`; confirm against `_fill_shape`).
      For an iterable of length `n + 2`: `stride` itself, unmodified.

    Raises:
      ValueError: If `stride` needs padding and `data_format` is neither
        `DATA_FORMAT_NHWC` nor `DATA_FORMAT_NCHW`.
      base.IncompatibleShapeError: If `stride` is neither an integer nor an
        iterable of an accepted length.
    """
    # `collections.Iterable` was removed in Python 3.10; the ABCs live in
    # `collections.abc`. Imported locally so the block does not depend on the
    # submodule having been loaded elsewhere, with a fallback for interpreters
    # that predate `collections.abc`.
    try:
        from collections import abc as collections_abc
    except ImportError:
        import collections as collections_abc

    is_iterable = isinstance(stride, collections_abc.Iterable)

    if isinstance(stride, numbers.Integral) or (
            is_iterable and len(stride) <= n):
        if data_format == DATA_FORMAT_NHWC:
            return (1,) + _fill_shape(stride, n) + (1,)
        elif data_format == DATA_FORMAT_NCHW:
            return (1, 1,) + _fill_shape(stride, n)
        else:
            # Plain `{}` instead of `{:s}`: the `s` format code is rejected
            # with a TypeError for non-string arguments (e.g. a tuple/set of
            # supported formats, or a non-string `data_format`), which would
            # hide the intended ValueError. Output for strings is unchanged.
            raise ValueError("Invalid data_format {}. Allowed formats "
                             "{}".format(data_format, SUPPORTED_DATA_FORMATS))
    elif is_iterable and len(stride) == n + 2:
        # Already carries both padding entries; pass it through as-is.
        return stride
    else:
        raise base.IncompatibleShapeError(
            "stride is {} ({}), must be either a positive integer or an iterable of"
            " positive integers of size {}".format(stride, type(stride), n))
评论列表
文章目录