def conv_cond_concat(x, y):
    """Concatenate a conditioning tensor onto a feature map along axis 3.

    ``y`` is tiled with multiples ``[1, H, W, 1]``, where ``H`` and ``W`` are
    dimensions 1 and 2 of ``x``. The tiled tensor is then concatenated with
    ``x`` on axis 3.

    Args:
        x: Feature map. Dimensions 1, 2 and 3 are used, so it is treated as
            rank 4. Presumably NHWC; confirm against callers.
        y: Conditioning tensor that ``tf.tile`` can tile with a length-4
            multiples vector. Presumably shaped ``[batch, 1, 1, y_dim]``, so
            tiling yields ``[batch, H, W, y_dim]``; TODO confirm with callers.

    Returns:
        ``tf.concat([x, tiled_y], axis=3)``.
    """
    # Use static spatial dims when they are known, so the result keeps its
    # static shape. Fall back to the runtime shape when they are unknown
    # (None) at graph-build time: an unknown static dimension cannot be
    # converted to a tensor, so building the multiples from it would fail.
    static_dims = x.get_shape().as_list()
    dynamic_dims = tf.shape(x)
    height = static_dims[1] if static_dims[1] is not None else dynamic_dims[1]
    width = static_dims[2] if static_dims[2] is not None else dynamic_dims[2]

    tiled_y = tf.tile(y, tf.stack([1, height, width, 1]))
    return tf.concat([x, tiled_y], axis=3)
评论列表
文章目录