vanilla_vae.py source code

python

Project: LifelongVAE · Author: jramapuram
def _create_optimizer(self, tvars, cost, lr):
        # optimizer = tf.train.RMSPropOptimizer(self.learning_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        print('there are %d trainable vars in cost %s\n' % (len(tvars), cost.name))
        grads = tf.gradients(cost, tvars)

        # DEBUG: exploding gradients test with this:
        # for index in range(len(grads)):
        #     if grads[index] is not None:
        #         gradstr = "\n grad [%i] | tvar [%s] =" % (index, tvars[index].name)
        #         grads[index] = tf.Print(grads[index], [grads[index]], gradstr, summarize=100)

        # grads, _ = tf.clip_by_global_norm(grads, 5.0)
        self.grad_norm = tf.norm(tf.concat([tf.reshape(t, [-1]) for t in grads],
                                           axis=0))
        return optimizer.apply_gradients(zip(grads, tvars))
        # return tf.train.AdamOptimizer(learning_rate=lr).minimize(cost, var_list=tvars)
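The `self.grad_norm` line computes the global L2 norm of all gradients by flattening each gradient tensor, concatenating them, and taking the norm of the result. A minimal NumPy sketch of that same computation, using hypothetical arrays in place of the TensorFlow gradient tensors returned by `tf.gradients(cost, tvars)`:

```python
import numpy as np

# Hypothetical gradient tensors of different shapes, standing in for
# the list returned by tf.gradients(cost, tvars).
grads = [np.array([[1.0, 2.0], [2.0, 4.0]]), np.array([2.0])]

# Flatten each gradient and concatenate into one vector, then take its
# L2 norm -- equivalent to the tf.norm(tf.concat([...], axis=0)) above.
flat = np.concatenate([g.reshape(-1) for g in grads])
global_norm = np.linalg.norm(flat)
print(global_norm)  # sqrt(1 + 4 + 4 + 16 + 4) = sqrt(29)
```

This is also the quantity that the commented-out `tf.clip_by_global_norm(grads, 5.0)` would clip against: if the global norm exceeds 5.0, every gradient is scaled down by the same factor so that the combined norm equals 5.0.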