def ranknet(x, relevance_labels, learning_rate, n_hidden, build_vars_fn, score_with_batchnorm_update_fn, score_fn):
    """Build a RankNet training graph with explicitly computed lambda gradients.

    The parameter gradients are not obtained by differentiating the cost.
    Instead, for every parameter tensor ``W_k`` the function averages
    ``lambda_ij * (ds_i/dW_k - ds_j/dW_k)`` over all pairs ``(i, j)`` and
    hands the result to Adam.

    Args:
        x: Input tensor. It is iterated along axis 0 (one example per slice).
        relevance_labels: Relevance label tensor. It is differenced with its
            own transpose, so it is presumably a float column of shape
            ``(n, 1)`` -- TODO confirm against the caller.
        learning_rate: Learning rate passed to ``tf.train.AdamOptimizer``.
        n_hidden: Not used by this function; kept for interface
            compatibility.
        build_vars_fn: Zero-argument callable returning the model parameters
            as an iterable of iterables of variables.
        score_with_batchnorm_update_fn: Callable ``(x, params)`` returning
            ``(predicted_scores, bn_params)``.
        score_fn: Callable ``(x, bn_params, params)`` returning predicted
            scores.

    Returns:
        A tuple ``(non_updating_cost, optimizer, get_score)``:
        ``non_updating_cost`` is the mean pairwise sigmoid cross-entropy
        tensor computed from ``score_fn``; ``optimizer(sess, feed_dict)``
        runs one Adam step; ``get_score(sess, feed_dict)`` evaluates and
        returns the scores produced by ``score_fn``.
    """
    sigma = 1
    print('USING SIGMA = %f' % sigma)

    params = build_vars_fn()
    predicted_scores, bn_params = score_with_batchnorm_update_fn(x, params)

    # Pairwise label difference clipped to [-1, 1].
    S_ij = tf.clip_by_value(relevance_labels - tf.transpose(relevance_labels), -1., 1.)

    # lambda_ij = sigma * (0.5 * (1 - S_ij) - 1 / (1 + exp(sigma * (s_i - s_j)))).
    # The literal 0.5 is used instead of (1/2) so the factor cannot collapse
    # to 0 under Python 2 integer division.
    pairwise_predicted_scores = predicted_scores - tf.transpose(predicted_scores)
    lambdas = (sigma * 0.5 * (1 - S_ij)
               - sigma * tf.divide(1, (1 + tf.exp(sigma * pairwise_predicted_scores))))

    # Reported cost: uses score_fn with the bn_params returned above and the
    # same clipped label differences as the lambdas.
    # NOTE(review): the logits below are not scaled by sigma, while the
    # lambdas are; the two only agree while sigma == 1.
    non_updating_predicted_scores = score_fn(x, bn_params, params)
    target_probabilities = 0.5 * (1 + S_ij)
    non_updating_pairwise_predicted_scores = (
        non_updating_predicted_scores - tf.transpose(non_updating_predicted_scores))
    non_updating_cost = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=non_updating_pairwise_predicted_scores,
            labels=target_probabilities))

    def get_derivative(W_k):
        """Return the pair-averaged lambda-weighted gradient for ``W_k``."""
        # Gradient of each example's score w.r.t. W_k, one example at a time,
        # stacked along a new leading axis by tf.map_fn.
        dsi_dWk = tf.map_fn(
            lambda x_i: tf.squeeze(
                tf.gradients(score_fn(tf.expand_dims(x_i, 0), bn_params, params), [W_k])[0]),
            x)
        # Pairwise differences: entry [i, j] is ds_i/dW_k - ds_j/dW_k.
        dsi_dWk_minus_dsj_dWk = tf.expand_dims(dsi_dWk, 1) - tf.expand_dims(dsi_dWk, 0)
        # Append trailing size-1 axes to lambdas so it broadcasts against the
        # parameter axes of the pairwise gradient difference.
        desired_lambdas_shape = tf.concat(
            [tf.shape(lambdas),
             tf.ones([tf.rank(dsi_dWk_minus_dsj_dWk) - tf.rank(lambdas)], dtype=tf.int32)],
            axis=0)
        return tf.reduce_mean(
            tf.reshape(lambdas, desired_lambdas_shape) * dsi_dWk_minus_dsj_dWk,
            axis=[0, 1])

    flat_params = [Wk for pk in params for Wk in pk]
    grads = [get_derivative(Wk) for Wk in flat_params]

    adam = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # tf.squeeze above may have dropped size-1 axes, so each gradient is
    # reshaped back to its parameter's shape before being applied.
    adam_op = adam.apply_gradients(
        [(tf.reshape(grad, tf.shape(param)), param)
         for grad, param in zip(grads, flat_params)])

    def optimizer(sess, feed_dict):
        """Run one Adam update step in ``sess``."""
        sess.run(adam_op, feed_dict=feed_dict)

    def get_score(sess, feed_dict):
        """Evaluate and return the scores produced by ``score_fn``."""
        return sess.run(non_updating_predicted_scores, feed_dict=feed_dict)

    return non_updating_cost, optimizer, get_score
# Web-page residue: site navigation labels ("Comment list", "Table of contents"); not part of the code.