test_sgld.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:chemblnet 作者: jaak-s 项目源码 文件源码
def test_psgld_sparse(self):
        tf.reset_default_graph()

        z     = tf.Variable(tf.zeros((5, 2)), dtype=tf.float32)
        idx   = tf.placeholder(tf.int32)
        zi    = tf.gather(z, idx)
        zloss = tf.square(zi - [10.0, 5.0])

        psgld = pSGLD(learning_rate=0.4)
        train_op_psgld = psgld.minimize(zloss)

        sess = tf.InteractiveSession()
        sess.run(tf.global_variables_initializer())

        self.assertTrue(np.alltrue(sess.run(z) == 0.0))

        sess.run(train_op_psgld, feed_dict={idx: 3})
        zh = sess.run(z)
        self.assertTrue(np.alltrue(zh[[0, 1, 2, 4], :] == 0.0))
        self.assertTrue(zh[3, 0] > 0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号