def autoformat_stride_3d(strides):
    """Normalize a 3-D convolution stride spec into a 5-element list.

    The result has the form ``[1, s0, s1, s2, 1]``. The outer entries are
    fixed to 1 (presumably the batch and channel axes of an NDHWC layout;
    confirm against the calling conv op).

    Args:
        strides: One of:
            * an integer scalar (``int`` or any ``numbers.Integral`` such as
              a NumPy integer), used for all three inner positions;
            * a length-3 ``tuple``/``list``/``tf.TensorShape``, placed
              between the two fixed 1s;
            * a length-5 ``tuple``/``list``/``tf.TensorShape`` whose first
              and last entries are already 1. It is copied into a new list.

    Returns:
        A new 5-element ``list``.

    Raises:
        TypeError: ``strides`` is neither an integer nor a supported
            sequence type.
        ValueError: the sequence length is not 3 or 5, or a length-5
            sequence does not have ``strides[0] == strides[4] == 1``.
    """
    # Local import: the file's import block is not visible from here.
    import numbers

    # numbers.Integral covers int (and bool) as well as NumPy integer
    # scalars, which are registered as Integral.
    if isinstance(strides, numbers.Integral):
        return [1, strides, strides, strides, 1]

    # Test the plain Python sequences first so that ``tf`` is only looked
    # up for inputs that could actually be a TensorShape.
    if not (isinstance(strides, (tuple, list))
            or isinstance(strides, tf.TensorShape)):
        raise TypeError("strides format error: " + str(type(strides)))

    if len(strides) == 3:
        return [1, strides[0], strides[1], strides[2], 1]

    if len(strides) == 5:
        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would let invalid strides through.
        if not strides[0] == strides[4] == 1:
            raise ValueError("Must have strides[0] = strides[4] = 1")
        return [strides[0], strides[1], strides[2], strides[3], strides[4]]

    raise ValueError("strides length error: " + str(len(strides))
                     + ", only a length of 3 or 5 is supported.")
# Auto format kernel for 3d convolution
评论列表
文章目录