util.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tefla 作者: litan 项目源码 文件源码
def show_vars(logger=None, trainable_scopes=None):
    """Print every model variable, grouped by kind, followed by parameter counts.

    Three groups are reported, in this order:

    * ``Trainable`` -- the variables returned by
      ``trainable_variables(trainable_scopes)``;
    * ``Non Trainable`` -- every entry of ``tf.global_variables()`` that is
      not in the trainable group;
    * ``Local`` -- every entry of ``tf.local_variables()``.

    Within each group, variables are printed one per line as
    ``"<name> <shape>"``, sorted by name. Afterwards a dict mapping each group
    label to its total element count is pretty-printed. Variables whose
    shape is not fully defined are still listed, but are excluded from the
    count because their element count cannot be computed.

    Args:
        logger: optional object with an ``info`` method used for output; when
            ``None`` the builtin ``print`` is used instead.
        trainable_scopes: forwarded unchanged to ``trainable_variables``; its
            accepted form is defined by that helper.
    """
    printer = logger.info if logger is not None else print
    all_vars = set(tf.global_variables())
    trainable_vars = set(trainable_variables(trainable_scopes))
    non_trainable_vars = all_vars.difference(trainable_vars)
    local_vars = set(tf.local_variables())

    # Plain dict mutated from the nested helper below. The previous code used
    # a dummy class named ``nonlocal`` as a mutable namespace, but ``nonlocal``
    # is a reserved keyword in Python 3, which made the function a SyntaxError.
    num_params = {}

    def show_var_info(variables, var_type):
        # Print one group of variables and record its total element count.
        printer('\n---%s vars in model:' % var_type)
        total_params = 0
        for var in sorted(variables, key=lambda v: v.name):
            shape = var.get_shape()
            printer('%s %s' % (var.name, shape))
            # as_list() raises for unknown rank and np.prod cannot multiply
            # None dims, so only fully-defined shapes contribute to the count.
            if shape.is_fully_defined():
                # int(): np.prod([]) (a scalar variable) returns the float 1.0.
                total_params += int(np.prod(shape.as_list()))
        num_params[var_type] = total_params

    show_var_info(trainable_vars, 'Trainable')
    show_var_info(non_trainable_vars, 'Non Trainable')
    show_var_info(local_vars, 'Local')
    printer('Total number of params:')
    printer(pprint.pformat(num_params))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号