def _preprocess_conv2d_input(x, data_format):
    """Prepare an input tensor for a 2D convolution.

    A float64 input is cast down to float32. When `data_format` is
    'channels_first', the axes are additionally permuted so that the
    channel axis ends up last.

    Arguments:
        x: input tensor.
        data_format: string, one of 'channels_last', 'channels_first'.

    Returns:
        A tensor.
    """
    is_double_precision = dtype(x) == 'float64'
    if is_double_precision:
        x = math_ops.cast(x, 'float32')

    # Any layout other than channels-first is passed through without a
    # transpose.
    if data_format != 'channels_first':
        return x

    # TF expects the channel dimension last rather than in position 1:
    #   TH input shape: (samples, input_depth, rows, cols)
    #   TF input shape: (samples, rows, cols, input_depth)
    to_channels_last = (0, 2, 3, 1)
    return array_ops.transpose(x, to_channels_last)
评论列表
文章目录