def broadcast_mult(inputs1, inputs2):
""""""
inputs1_shape = tf.shape(inputs1)
inputs_size = inputs1.get_shape().as_list()[-1]
inputs2_shape = tf.shape(inputs2)
inputs1 = tf.transpose(inputs1, [0,2,1])
inputs2 = tf.transpose(inputs2, [0,2,1])
inputs1 = tf.reshape(inputs1, tf.pack([-1,inputs1_shape[1],1]))
inputs2 = tf.reshape(inputs2, tf.pack([-1,1,inputs2_shape[1]]))
inputs = inputs1 * inputs2
inputs = tf.reshape(inputs, tf.pack([inputs1_shape[0], inputs1_shape[2], inputs1_shape[1], inputs2_shape[1]]))
inputs = tf.transpose(inputs, [0,2,3,1])
inputs.set_shape([tf.Dimension(None)]*3 + [tf.Dimension(inputs_size)])
return inputs
#***************************************************************
评论列表
文章目录