def getStatsEigen(self, stats=None):
if len(self.stats_eigen) == 0:
stats_eigen = {}
if stats is None:
stats = self.stats
tmpEigenCache = {}
with tf.device('/cpu:0'):
for var in stats:
for key in ['fprop_concat_stats', 'bprop_concat_stats']:
for stats_var in stats[var][key]:
if stats_var not in tmpEigenCache:
stats_dim = stats_var.get_shape()[1].value
e = tf.Variable(tf.ones(
[stats_dim]), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e', trainable=False)
Q = tf.Variable(tf.diag(tf.ones(
[stats_dim])), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q', trainable=False)
stats_eigen[stats_var] = {'e': e, 'Q': Q}
tmpEigenCache[
stats_var] = stats_eigen[stats_var]
else:
stats_eigen[stats_var] = tmpEigenCache[
stats_var]
self.stats_eigen = stats_eigen
return self.stats_eigen
评论列表
文章目录