绕过tf.argmax这是不可区分的

发布于 2021-01-29 17:23:25

我已经为我的神经网络编写了一个自定义损失函数,但是它无法计算任何梯度。我认为这是因为我需要最大值的索引,因此正在使用argmax来获取此索引。

由于argmax不可区分,所以我解决了这个问题,但是我不知道这是怎么可能的。

有人可以帮忙吗?

关注者
0
被浏览
51
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    如果您很酷,

    import tensorflow as tf
    import numpy as np
    
    sess = tf.Session()
    x = tf.placeholder(dtype=tf.float32, shape=(None,))
    beta = tf.placeholder(dtype=tf.float32)
    
    # Pseudo-math for the below
    # y = sum( i * exp(beta * x[i]) ) / sum( exp(beta * x[i]) )
    y = tf.reduce_sum(tf.cumsum(tf.ones_like(x)) * tf.exp(beta * x) / tf.reduce_sum(tf.exp(beta * x))) - 1
    
    print("I can compute the gradient", tf.gradients(y, x))
    
    for run in range(10):
        data = np.random.randn(10)
        print(data.argmax(), sess.run(y, feed_dict={x:data/np.linalg.norm(data), beta:1e2}))
    

    这是一种技巧,可以在低温环境中计算平均值,从而得出概率空间的近似最大值。在这种情况下,低温与温度过高有关beta

    实际上,随着beta逼近无穷大,我的算法将收敛到最大值(假设最大值是唯一的)。不幸的是,在您遇到数值错误并得到之前,beta不能太大NaN,但是有一些技巧可以解决,如果您需要的话,我可以研究一下。

    输出看起来像这样,

    0 2.24459
    9 9.0
    8 8.0
    4 4.0
    4 4.0
    8 8.0
    9 9.0
    6 6.0
    9 8.99995
    1 1.0
    

    因此,您可以看到它在某些地方变得混乱,但通常会得到正确的答案。根据您的算法,这可能很好。



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看