core.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:tensorlight 作者: bsautermeister 项目源码 文件源码
def show_trainable_parameters(verbose=False):
    """Shows the number of trainable parameters in this graph.

    Prints a table of trainable-parameter counts, grouped by the first
    component of each variable name (the text before the first '/').
    The table ends with the grand total. If there are no trainable
    variables, prints a short notice and returns.

    Parameters
    ----------
    verbose: Boolean, optional
        Show additional information and list the number of trainable
        variables per variable, not just the total sum.

    Raises
    ------
    ValueError
        If a trainable variable's shape is not fully defined, so its
        parameter count cannot be computed.
    """
    total_width = 80
    trainable_vars = tf.trainable_variables()

    if not trainable_vars:
        print("No model-params found.")
        return

    if verbose:
        print("-" * total_width)

    total_parameters = 0
    groups = {}
    for var in trainable_vars:
        # num_elements() multiplies all dimensions (1 for a scalar) and works
        # on both TF1 Dimension-based and TF2 int-based shapes; it returns
        # None when any dimension is unknown.
        var_params = var.get_shape().num_elements()
        if var_params is None:
            raise ValueError(
                "Cannot count parameters of variable '{}': shape {} is not "
                "fully defined.".format(var.name, var.get_shape()))
        if verbose:
            print("{:69} | {:8d}".format(var.name, var_params))

        total_parameters += var_params

        # Group by top-level name scope, e.g. 'conv1/weights:0' -> 'conv1'.
        group_name = var.name.split('/')[0]
        groups[group_name] = groups.get(group_name, 0) + var_params

    print("-" * total_width)
    # .items() (not the Python-2-only .iteritems()) so this runs on Python 3.
    for group, count in groups.items():
        print("{:69} | {:8d}".format(group, count))
    print("=" * total_width)
    print("{:69} | {:8d}".format("TOTAL", total_parameters))
    print("-" * total_width)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号