cifar.py source code

python

Project: information-dropout  Author: ucla-vision
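import numpy as np
import tensorflow as tf

def correlation(task,load=True):
    # `mytask` is not defined in this excerpt; presumably a module-level model object in cifar.py
    self = mytask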
    if load:
        self.initialize(_load=True, _logging=False, _log_dir='other/')
    data = []
    # Evaluate the tensor from the 'log_network' collection on every validation batch
    for batch in self.iterate_minibatches('valid'):
        xtrain, ytrain = batch
        ytrain = np.eye(10)[ytrain]  # one-hot encode the 10 class labels
        feed_dict = {
            self.x: xtrain,
            self.y: ytrain,
            self.sigma0: 1.,
            self.initial_keep_prob: task['initial_keep_prob'],
            self.is_training: False,
        }
        z = tf.get_collection('log_network')[-1]
        batch_z = self.sess.run(z, feed_dict)
        data.append(batch_z)
    # Stack all batches, then flatten each sample to a 1-D feature vector
    data = np.vstack(data)
    data = data.reshape(data.shape[0], -1)
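    def normal_tc(c0):
        # Per-dimension total correlation of a Gaussian with covariance c0:
        # -1/2 * log det(D^-1 c0) / n, where D = diag(c0) and n = c0.shape[0]
        c1i = np.diag(1./np.diag(c0))
        p = np.matmul(c1i,c0)
        return - .5 * np.linalg.slogdet(p)[1] / c0.shape[0]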
    c0 = np.cov( data, rowvar=False )
    tc = normal_tc(c0)
    print("Total correlation: %f" % tc)
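
The Gaussian total-correlation estimate computed by normal_tc above can be tried on its own, without the model or a TensorFlow session. The following is a minimal sketch on synthetic data, not part of the original cifar.py: the function name gaussian_total_correlation, the random seed, and the way the correlated samples are built are all assumptions made for illustration.

```python
import numpy as np


def gaussian_total_correlation(samples):
    """Per-dimension total correlation, assuming the data is Gaussian.

    samples: array of shape (n_samples, n_features).
    Hypothetical helper; it mirrors the normal_tc computation above.
    """
    cov = np.cov(samples, rowvar=False)
    # Scale row i by 1 / cov[i, i]; same as np.diag(1. / np.diag(cov)) @ cov
    scaled = cov / np.diag(cov)[:, None]
    # slogdet returns (sign, log|det|); only the log-determinant is used
    logdet = np.linalg.slogdet(scaled)[1]
    return -0.5 * logdet / cov.shape[0]


rng = np.random.RandomState(0)  # arbitrary seed, chosen for repeatability

# Four independent standard-normal features: estimate should be near zero
independent = rng.randn(5000, 4)

# Add a shared component (twice the first feature) to every column,
# which makes all features strongly correlated
correlated = independent + 2.0 * independent[:, :1]

tc_ind = gaussian_total_correlation(independent)
tc_cor = gaussian_total_correlation(correlated)
print("independent: %f" % tc_ind)
print("correlated:  %f" % tc_cor)
```

For independent features the estimate stays close to zero, and it grows as the features share more variance, which is the behavior the correlation function relies on when it reports a single number for the network's activations.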