def autoformat_kernel_2d(strides):
    """Normalize a 2-D kernel-size or stride spec into a 4-element list.

    Arguments:
        strides: One of:
            - an ``int`` ``s``, used for both middle positions;
            - a tuple, list or ``tf.TensorShape`` of length 2 ``(a, b)``,
              placed in the middle of the result;
            - a tuple, list or ``tf.TensorShape`` of length 4, whose
              elements are copied unchanged.

    Returns:
        A new list with 4 elements: ``[1, s, s, 1]``, ``[1, a, b, 1]`` or
        the 4 given values. (Presumably laid out as
        [batch, height, width, channels] for ``tf.nn`` ops -- confirm
        against callers.)

    Raises:
        TypeError: ``strides`` is neither an int nor a supported sequence.
        ValueError: ``strides`` is a sequence whose length is not 2 or 4.
        Both subclass ``Exception``, so callers that caught the previous
        generic ``Exception`` keep working.
    """
    if isinstance(strides, int):
        return [1, strides, strides, 1]

    # Plain tuples/lists are checked first; ``tf.TensorShape`` is only
    # looked up for any other input type.
    is_supported_sequence = (isinstance(strides, (tuple, list))
                             or isinstance(strides, tf.TensorShape))
    if not is_supported_sequence:
        raise TypeError("strides format error: {}".format(type(strides)))

    if len(strides) == 2:
        return [1, strides[0], strides[1], 1]
    if len(strides) == 4:
        return [strides[0], strides[1], strides[2], strides[3]]
    raise ValueError("strides length error: {}, only a length of 2 or 4 "
                     "is supported.".format(len(strides)))
# Auto format filter size
# Output shape: (rows, cols, input_depth, out_depth)
评论列表
文章目录