def _postprocess_conv3d_output(x, data_format):
    """Undo the layout/dtype adjustments applied around a conv3d call.

    Two independent steps are applied, in this order:

    1. If ``data_format`` is ``'channels_first'``, the axes are permuted with
       ``(0, 4, 1, 2, 3)``. This moves input axis 4 to position 1. It
       presumably converts a channels-last conv3d result back to
       channels-first; it assumes ``x`` is rank 5 (TODO confirm with callers).
    2. If the global float type reported by ``floatx()`` is ``'float64'``,
       the result is cast to ``'float64'``.

    Arguments:
        x: A tensor.
        data_format: string, one of "channels_last", "channels_first".

    Returns:
        A tensor.
    """
    # Any value other than 'channels_first' leaves the layout unchanged.
    out = (array_ops.transpose(x, (0, 4, 1, 2, 3))
           if data_format == 'channels_first' else x)

    # NOTE(review): presumably the conv3d ran at lower precision and this
    # restores the configured float64 type; confirm against the matching
    # pre-processing helper.
    if floatx() == 'float64':
        out = math_ops.cast(out, 'float64')
    return out
评论列表
文章目录