models.py Source Code

python

Project: learning-rank-public  Author: andreweskeclarke
import tensorflow as tf


def ranknet(x, relevance_labels, learning_rate, n_hidden, build_vars_fn, score_with_batchnorm_update_fn, score_fn):
    n_out = 1                 # unused in this function; kept from the original
    sigma = 1                 # scale of the RankNet sigmoid
    n_data = tf.shape(x)[0]   # unused in this function; kept from the original

    print('USING SIGMA = %f' % sigma)
    params = build_vars_fn()
    # Forward pass that also updates the batch-norm statistics.
    predicted_scores, bn_params = score_with_batchnorm_update_fn(x, params)
    # S_ij in {-1, 0, 1}: sign of the relevance difference between items i and j.
    S_ij = tf.maximum(tf.minimum(1., relevance_labels - tf.transpose(relevance_labels)), -1.)
    real_scores = 0.5 * (1. + S_ij)  # target pair probabilities (1 + S_ij)/2; unused below, kept from the original
    pairwise_predicted_scores = predicted_scores - tf.transpose(predicted_scores)
    # RankNet lambdas: dC_ij/ds_i = sigma*((1/2)*(1 - S_ij) - 1/(1 + exp(sigma*(s_i - s_j)))).
    lambdas = sigma * 0.5 * (1. - S_ij) - sigma * tf.divide(1., 1. + tf.exp(sigma * pairwise_predicted_scores))

    # Evaluation pass: reads the batch-norm statistics without updating them.
    non_updating_predicted_scores = score_fn(x, bn_params, params)
    non_updating_S_ij = tf.maximum(tf.minimum(1., relevance_labels - tf.transpose(relevance_labels)), -1.)
    non_updating_real_scores = 0.5 * (1. + non_updating_S_ij)
    non_updating_pairwise_predicted_scores = non_updating_predicted_scores - tf.transpose(non_updating_predicted_scores)
    non_updating_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=non_updating_pairwise_predicted_scores, labels=non_updating_real_scores))

    def get_derivative(W_k):
        # ds_i/dW_k for every example i, computed one example at a time.
        dsi_dWk = tf.map_fn(lambda x_i: tf.squeeze(tf.gradients(score_fn(tf.expand_dims(x_i, 0), bn_params, params), [W_k])[0]), x)
        # Pairwise differences ds_i/dW_k - ds_j/dW_k, broadcast over all (i, j).
        dsi_dWk_minus_dsj_dWk = tf.expand_dims(dsi_dWk, 1) - tf.expand_dims(dsi_dWk, 0)
        # Pad the lambda matrix with trailing singleton dims so it broadcasts against the parameter shape.
        desired_lambdas_shape = tf.concat([tf.shape(lambdas), tf.ones([tf.rank(dsi_dWk_minus_dsj_dWk) - tf.rank(lambdas)], dtype=tf.int32)], axis=0)
        # dC/dW_k: average of lambda_ij * (ds_i/dW_k - ds_j/dW_k) over all (i, j) pairs.
        return tf.reduce_mean(tf.reshape(lambdas, desired_lambdas_shape) * dsi_dWk_minus_dsj_dWk, axis=[0, 1])

    flat_params = [Wk for pk in params for Wk in pk]
    grads = [get_derivative(Wk) for Wk in flat_params]
    adam = tf.train.AdamOptimizer(learning_rate=learning_rate)
    adam_op = adam.apply_gradients([(tf.reshape(grad, tf.shape(param)), param) for grad, param in zip(grads, flat_params)])

    def optimizer(sess, feed_dict):
        sess.run(adam_op, feed_dict=feed_dict)

    def get_score(sess, feed_dict):
        return sess.run(non_updating_predicted_scores, feed_dict=feed_dict)

    return non_updating_cost, optimizer, get_score
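
For reference, the lambdas computed above are the per-pair derivative of the standard RankNet cross-entropy cost (Burges et al., 2005), with model scores $s_i$ and label sign $S_{ij}$:

$$P_{ij} = \frac{1}{1 + e^{-\sigma (s_i - s_j)}}, \qquad \bar{P}_{ij} = \tfrac{1}{2}(1 + S_{ij}), \qquad C_{ij} = -\bar{P}_{ij}\log P_{ij} - (1 - \bar{P}_{ij})\log(1 - P_{ij})$$

$$\lambda_{ij} = \frac{\partial C_{ij}}{\partial s_i} = \sigma\left(\tfrac{1}{2}(1 - S_{ij}) - \frac{1}{1 + e^{\sigma (s_i - s_j)}}\right)$$

The function leaves the model itself to the caller through the three callbacks. Below is a minimal usage sketch showing the expected calling convention; the linear scorer, the dummy batch-norm hooks, and all shapes are illustrative assumptions, not part of the repository:

import numpy as np
import tensorflow as tf

n_features, n_hidden = 10, 16

x = tf.placeholder(tf.float32, [None, n_features])
relevance_labels = tf.placeholder(tf.float32, [None, 1])

def build_vars_fn():
    # params is a list of parameter groups; ranknet flattens it with
    # [Wk for pk in params for Wk in pk].
    W = tf.Variable(tf.random_normal([n_features, 1], stddev=0.1))
    b = tf.Variable(tf.zeros([1]))
    return [[W, b]]

def score_with_batchnorm_update_fn(x, params):
    (W, b), = params
    return tf.matmul(x, W) + b, None  # no real batch norm here, so bn_params is a dummy

def score_fn(x, bn_params, params):
    (W, b), = params
    return tf.matmul(x, W) + b

cost, optimizer, get_score = ranknet(
    x, relevance_labels, learning_rate=0.01, n_hidden=n_hidden,
    build_vars_fn=build_vars_fn,
    score_with_batchnorm_update_fn=score_with_batchnorm_update_fn,
    score_fn=score_fn)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {x: np.random.randn(8, n_features),
            relevance_labels: np.random.randint(0, 3, (8, 1)).astype(np.float32)}
    for _ in range(10):
        optimizer(sess, feed)
    print(sess.run(cost, feed_dict=feed))
    print(get_score(sess, feed))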