def show_vars(logger=None, trainable_scopes=None):
    """Print every model variable with its shape, plus parameter totals.

    Variables are reported in three groups, each sorted by variable name:
    trainable, non-trainable (global variables not selected as trainable),
    and local variables.

    Args:
        logger: Optional logger. When given, output goes to ``logger.info``;
            otherwise it goes to ``print``.
        trainable_scopes: Forwarded to ``trainable_variables`` to decide
            which variables count as trainable.

    Returns:
        dict mapping each group label ('Trainable', 'Non Trainable',
        'Local') to that group's total number of parameters. The original
        returned None, so callers that ignore the result are unaffected.
    """
    printer = logger.info if logger is not None else print
    all_vars = set(tf.global_variables())
    trainable_vars = set(trainable_variables(trainable_scopes))
    non_trainable_vars = all_vars.difference(trainable_vars)
    local_vars = set(tf.local_variables())

    # The inner helper mutates this dict in place, so no `nonlocal` is
    # needed. The original used `class nonlocal: pass` as a namespace, which
    # is a SyntaxError in Python 3 (`nonlocal` is a reserved keyword).
    num_params = {}

    def show_var_info(variables, var_type):
        """Print one group of variables and record its parameter total."""
        printer('\n---%s vars in model:' % var_type)
        total_params = 0
        for var in sorted(variables, key=lambda v: v.name):
            shape = var.get_shape()
            printer('%s %s' % (var.name, shape))
            # dtype=int64 keeps the count integral. Without it, np.prod([])
            # (a scalar variable) is the float 1.0 and turns the total into
            # a float.
            # NOTE(review): a shape with unknown dimensions (None entries)
            # still fails here, as it did in the original.
            total_params += int(np.prod(shape.as_list(), dtype=np.int64))
        num_params[var_type] = total_params

    show_var_info(trainable_vars, 'Trainable')
    show_var_info(non_trainable_vars, 'Non Trainable')
    show_var_info(local_vars, 'Local')
    printer('Total number of params:')
    printer(pprint.pformat(num_params))
    return num_params
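# Comment list (web-page residue; original text: "评论列表")
# Table of contents (web-page residue; original text: "文章目录")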