def broadcast_add(inputs1, inputs2):
""""""
inputs1_shape = tf.shape(inputs1)
inputs_size = inputs1.get_shape().as_list()[-1]
inputs2_shape = tf.shape(inputs2)
inputs1 = tf.transpose(inputs1, [0,2,1])
inputs2 = tf.transpose(inputs2, [0,2,1])
inputs1 = tf.reshape(inputs1, tf.pack([-1,inputs1_shape[1],1]))
inputs2 = tf.reshape(inputs2, tf.pack([-1,1,inputs2_shape[1]]))
inputs = inputs1 + inputs2
inputs = tf.reshape(inputs, [inputs1_shape[0], inputs1_shape[2], inputs1_shape[1], inputs2_shape[1]])
inputs = tf.transpose(inputs, [0,2,3,1])
inputs.set_shape([tf.Dimension(None)]*3 + [tf.Dimension(inputs_size)])
return inputs
#===============================================================
评论列表
文章目录