def call(self, data, mask=None):
    """Fold each 2x2 block of axes 2 and 3 of ``data`` into axis 1.

    ``data`` is sub-sampled with stride 2 along axes 2 and 3 at the four
    phase offsets (0, 0), (0, 1), (1, 0), (1, 1), in that order. The four
    results are concatenated along axis 1, so axis 1 becomes 4x larger and
    axes 2 and 3 are halved. Axis 0 is kept whole.

    NOTE(review): the layout is presumably channels-first
    (batch, channels, height, width) -- TODO confirm against the layer's
    configuration / callers.

    Args:
        data: rank-4 tensor. Axes 2 and 3 must have even length. Otherwise
            offset 0 yields one more element than offset 1, and the
            concat below fails on mismatched shapes.
        mask: unused. Presumably accepted only to match the Keras
            ``Layer.call`` signature.

    Returns:
        Tensor with axis 1 four times as large and axes 2 and 3 halved.
    """
    # Full-range strided indexing. Unlike the former fixed end index of
    # 1024 on axis 0, this never truncates large batches. It also needs no
    # static shapes, so unknown (None) dimensions are fine.
    phases = [data[:, :, row::2, col::2]
              for row in (0, 1)
              for col in (0, 1)]

    # Before TF 1.0, tf.concat took (concat_dim, values). From 1.0 on it
    # takes (values, axis).
    major_version = int(tf.__version__.split('.')[0])
    if major_version < 1:
        return tf.concat(1, phases)
    return tf.concat(phases, 1)
评论列表
文章目录