def _postprocess_conv2d_output(x, data_format):
    """Adjust the layout and dtype of a conv2d result when required.

    Two independent adjustments are applied, in this order:

    1. If ``data_format`` is ``'channels_first'``, the tensor is transposed
       with permutation ``(0, 3, 1, 2)``, i.e. its last axis is moved to
       position 1.
    2. If ``floatx()`` reports ``'float64'``, the tensor is cast to
       ``float64``.

    Arguments:
        x: A tensor. It must be rank 4 whenever the transpose applies (the
            permutation has four entries) -- presumably the channels-last
            output of a conv2d op; confirm against callers.
        data_format: string, one of "channels_last", "channels_first".
            Any value other than 'channels_first' leaves the layout as-is.

    Returns:
        A tensor: ``x`` with the adjustments above applied (``x`` itself if
        neither condition holds).
    """
    result = x
    if data_format == 'channels_first':
        # Moves the last axis to position 1; read as (N, H, W, C) ->
        # (N, C, H, W) assuming a channels-last input -- TODO confirm.
        result = array_ops.transpose(result, (0, 3, 1, 2))
    if floatx() == 'float64':
        # NOTE(review): presumably restores float64 after a matching
        # pre-processing step ran the convolution in a narrower dtype;
        # confirm against the caller.
        result = math_ops.cast(result, 'float64')
    return result
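# Comment list (stray web-page navigation label; originally "评论列表")
# Table of contents (stray web-page navigation label; originally "文章目录")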