def conditional_bilinear_classifier(self, inputs1, inputs2, n_classes, probs, add_bias1=True, add_bias2=True):
  """Compute per-class bilinear scores and weight them by `probs`.

  Both inputs optionally get dropout and an appended constant-1 feature column.
  They are then combined by `linalg.bilinear` into `n_classes` score maps, and
  the scores are multiplied (batched matmul) against `probs`.

  Args:
    inputs1: rank-3 tensor. Dynamic dims 0 and 1 are read as batch_size and
      bucket_size; its static last dimension (input_size) must be known.
    inputs2: rank-3 tensor. Assumed to have the same shape as inputs1, because
      its dropout noise shape, its ones column and its set_shape all reuse
      inputs1's dimensions.
    n_classes: number of output classes, forwarded to linalg.bilinear.
    probs: either a rank-2 tensor of integer indices (converted to a float
      one-hot over bucket_size), or a tensor of any other rank, which is used
      as-is with gradients stopped.
    add_bias1, add_bias2: forwarded unchanged to linalg.bilinear.

  Returns:
    A tuple (weighted_bilin, bilin):
      bilin: the output of linalg.bilinear.
      weighted_bilin: tf.batch_matmul(bilin, tf.expand_dims(probs, 3)).
    NOTE(review): bilin is presumably (batch, bucket, n_classes, bucket), which
    makes weighted_bilin (batch, bucket, n_classes, 1). Confirm against
    linalg.bilinear.
  """
  input_shape = tf.shape(inputs1)
  batch_size = input_shape[0]
  bucket_size = input_shape[1]
  input_size = inputs1.get_shape().as_list()[-1]
  # Static shape to restore after the concat below: batch/bucket stay dynamic,
  # and the feature dim grows by one for the appended constant-1 column.
  input_shape_to_set = [tf.Dimension(None), tf.Dimension(None), input_size+1]

  if len(probs.get_shape().as_list()) == 2:
    # Hard indices: expand to one-hot rows of length bucket_size (adds a last axis).
    probs = tf.to_float(tf.one_hot(tf.to_int64(probs), bucket_size, 1, 0))
  else:
    # Soft weights: use them as a fixed weighting; no gradient flows back through probs.
    probs = tf.stop_gradient(probs)

  # Dropout is applied only when self.moving_params is None
  # (presumably training mode rather than evaluation; verify with caller).
  if self.moving_params is None:
    keep_prob = self.mlp_keep_prob
  else:
    keep_prob = 1
  if isinstance(keep_prob, tf.Tensor) or keep_prob < 1:
    # Size-1 noise on axis 1 broadcasts one dropout mask across that axis.
    noise_shape = tf.pack([batch_size, 1, input_size])
    inputs1 = tf.nn.dropout(inputs1, keep_prob, noise_shape=noise_shape)
    inputs2 = tf.nn.dropout(inputs2, keep_prob, noise_shape=noise_shape)

  # Append a constant-1 feature to both inputs. Presumably this lets the
  # bilinear product also carry linear/bias terms.
  ones = tf.ones(tf.pack([batch_size, bucket_size, 1]))
  inputs1 = tf.concat(2, [inputs1, ones])
  inputs1.set_shape(input_shape_to_set)
  inputs2 = tf.concat(2, [inputs2, ones])
  inputs2.set_shape(input_shape_to_set)

  bilin = linalg.bilinear(inputs1, inputs2,
                          n_classes,
                          add_bias1=add_bias1,
                          add_bias2=add_bias2,
                          initializer=tf.zeros_initializer,
                          moving_params=self.moving_params)
  # Contract the class scores against probs via a batched matmul.
  weighted_bilin = tf.batch_matmul(bilin, tf.expand_dims(probs, 3))
  return weighted_bilin, bilin
#=============================================================
评论列表
文章目录