def correlation(task, load=True):
    """Print and return a normalized total-correlation estimate on the 'valid' split.

    The tensor stored last in the TensorFlow ``'log_network'`` collection is
    evaluated on every minibatch yielded by ``iterate_minibatches('valid')``.
    The results are stacked, flattened to one row per sample, and the
    covariance across samples is turned into a total-correlation figure by
    ``normal_tc`` (see the nested helper below).

    Args:
        task: Mapping that provides at least the key ``'initial_keep_prob'``;
            its value is fed to the model's ``initial_keep_prob`` placeholder.
        load: When true, ``initialize(_load=True, _logging=False,
            _log_dir='other/')`` is called on the model before evaluation
            (presumably restoring saved weights -- confirm against
            ``initialize``).

    Returns:
        The total-correlation value that is also printed. (Earlier versions
        of this function only printed it and returned ``None``.)

    Raises:
        ValueError: If the 'valid' split yields no minibatches.
    """
    # NOTE(review): `mytask` is not a parameter or local; it is resolved as a
    # module-level name defined elsewhere in this file.
    model = mytask
    if load:
        model.initialize(_load=True, _logging=False, _log_dir='other/')

    # The tensor to evaluate does not depend on the batch, so look it up once
    # instead of on every iteration.
    z = tf.get_collection('log_network')[-1]

    data = []
    for xbatch, ybatch in model.iterate_minibatches('valid'):
        feed_dict = {
            model.x: xbatch,
            # Rows of the 10x10 identity selected by label: a one-hot encoding
            # (assumes integer labels in [0, 10)).
            model.y: np.eye(10)[ybatch],
            model.sigma0: 1.,
            model.initial_keep_prob: task['initial_keep_prob'],
            model.is_training: False,
        }
        data.append(model.sess.run(z, feed_dict))

    if not data:
        # np.vstack([]) would raise a ValueError as well; fail with a message
        # that points at the actual cause.
        raise ValueError("no minibatches were produced for the 'valid' split")

    data = np.vstack(data)
    # One row per sample, all remaining axes flattened into columns.
    data = data.reshape(data.shape[0], -1)

    def normal_tc(c0):
        """Return -0.5 * log|det(D^-1 C)| / n for an n x n covariance matrix C.

        D is the diagonal part of C, so D^-1 C has the same determinant as
        the correlation matrix. The helper's name suggests this is the
        total correlation under a normal (Gaussian) model, divided by the
        number of dimensions.
        """
        # A zero on the diagonal of c0 (a zero-variance column) makes this
        # division produce inf, and the result is then not meaningful.
        c1i = np.diag(1. / np.diag(c0))
        p = np.matmul(c1i, c0)
        # slogdet returns (sign, log|det|); only the log-magnitude is used.
        return - .5 * np.linalg.slogdet(p)[1] / c0.shape[0]

    # rowvar=False: columns are variables, rows are observations.
    c0 = np.cov(data, rowvar=False)
    tc = normal_tc(c0)
    # Parenthesised single-argument form prints the same text on Python 2 and 3.
    print("Total correlation: %f" % tc)
    return tc
评论列表
文章目录