benchmark.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:tfutils 作者: neuroailab 项目源码 文件源码
def time_tf(data):
    """Time NSTEPS evaluations of a model's mean cross-entropy loss on ``data``.

    Builds ``model.alexnet_nonorm`` on ``data.batch['data']``, takes the mean
    sparse softmax cross-entropy of its ``output`` against
    ``data.batch['labels']``, and times ``NSTEPS`` runs of that loss in a
    fresh ``tf.Session``.

    Args:
        data: a data provider exposing ``kind`` (used in the progress-bar label
            and in the results) and ``batch`` (a dict mapping names to graph
            tensors; must contain ``'data'`` and ``'labels'``). It may also
            expose ``start_threads(sess)`` / ``stop_threads(sess)`` for a
            custom queue runner, or ``next()`` returning a dict keyed like
            ``batch`` whose values are fed through ``feed_dict``.

    Returns:
        pandas.DataFrame with columns ``kind``, ``stepno`` and ``dur`` (elapsed
        seconds for that step), one row per step.
    """
    import timeit  # default_timer: the stdlib's preferred clock for intervals

    m = model.alexnet_nonorm(data.batch['data'])
    # TF >= 1.0 rejects positional arguments here (and TF 0.x's positional
    # order was (logits, labels)), so both are passed by keyword.
    targets = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=data.batch['labels'], logits=m.output))

    # Decide once how each step gets its input: providers with their own queue
    # threads (or without next()) feed the graph themselves; otherwise a batch
    # is pulled in Python and passed via feed_dict.
    threaded = hasattr(data, 'start_threads')
    use_feed = not threaded and hasattr(data, 'next')

    durs = []
    init = tf.global_variables_initializer()
    sess = tf.Session()
    try:
        sess.run(init)

        # start our custom queue runner's threads
        if threaded:
            data.start_threads(sess)
        try:
            for step in tqdm.trange(NSTEPS, desc='running ' + data.kind):
                # Timed region includes fetching/feeding the batch, as before.
                start_time = timeit.default_timer()
                if use_feed:
                    batch = data.next()
                    feed_dict = {node: batch[name] for name, node in data.batch.items()}
                    sess.run(targets, feed_dict=feed_dict)
                else:
                    sess.run(targets)
                durs.append([data.kind, step, timeit.default_timer() - start_time])
        finally:
            # Stop producer threads even if a step raised, so they do not
            # outlive the session.
            if hasattr(data, 'stop_threads'):
                data.stop_threads(sess)
    finally:
        sess.close()

    return pandas.DataFrame(durs, columns=['kind', 'stepno', 'dur'])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号