TensorFlow:有没有一种方法可以测量模型的FLOPS?

发布于 2021-01-29 17:54:54

在此问题中找到了我能找到的最接近的示例:https
:
//github.com/tensorflow/tensorflow/issues/899

使用此最小的可复制代码:

import tensorflow as tf
import tensorflow.python.framework.ops as ops 
g = tf.Graph()
with g.as_default():
  A = tf.Variable(tf.random_normal( [25,16] ))
  B = tf.Variable(tf.random_normal( [16,9] ))
  C = tf.matmul(A,B) # shape=[25,9]
for op in g.get_operations():
  flops = ops.get_stats_for_node_def(g, op.node_def, 'flops').value
  if flops is not None:
    print 'Flops should be ~',2*25*16*9
    print '25 x 25 x 9 would be',2*25*25*9 # ignores internal dim, repeats first
    print 'TF stats gives',flops

但是,返回的FLOPS始终为“无”。有没有一种方法可以具体测量FLOPS,尤其是PB文件?

关注者
0
被浏览
143
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    有点晚了,但也许将来对某些访客有帮助。对于您的示例,我成功测试了以下代码段:

    g = tf.Graph()
    run_meta = tf.RunMetadata()
    with g.as_default():
        A = tf.Variable(tf.random_normal( [25,16] ))
        B = tf.Variable(tf.random_normal( [16,9] ))
        C = tf.matmul(A,B) # shape=[25,9]
    
        opts = tf.profiler.ProfileOptionBuilder.float_operation()    
        flops = tf.profiler.profile(g, run_meta=run_meta, cmd='op', options=opts)
        if flops is not None:
            print('Flops should be ~',2*25*16*9)
            print('25 x 25 x 9 would be',2*25*25*9) # ignores internal dim, repeats first
            print('TF stats gives',flops.total_float_ops)
    

    也可以将分析器与Keras以下代码段结合使用:

    import tensorflow as tf
    import keras.backend as K
    from keras.applications.mobilenet import MobileNet
    
    run_meta = tf.RunMetadata()
    with tf.Session(graph=tf.Graph()) as sess:
        K.set_session(sess)
        net = MobileNet(alpha=.75, input_tensor=tf.placeholder('float32', shape=(1,32,32,3)))
    
        opts = tf.profiler.ProfileOptionBuilder.float_operation()    
        flops = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)
    
        opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()    
        params = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)
    
        print("{:,} --- {:,}".format(flops.total_float_ops, params.total_parameters))
    

    希望我能帮上忙!



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看