benchmark.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:tfutils 作者: neuroailab 项目源码 文件源码
def search_queue_params(
        save_path='/home/qbilius/mh17/computed/search_queue_params.pkl'):
    """Benchmark a FIFO ``data.Queue`` over a grid of queue parameters.

    For every combination of thread count, data batch size and per-thread
    capacity, a fresh default graph is created, a ``DataHDF5`` source is fed
    through a FIFO ``data.Queue`` and the result of ``time_tf(queue)`` is
    collected. All runs are concatenated into one DataFrame, pickled to
    ``save_path``, and the mean ``dur`` per parameter combination is printed.

    Args:
        save_path: Destination of the pickled results DataFrame. Defaults to
            the location the original script wrote to.

    Returns:
        None. Results are written to ``save_path`` and printed to stdout.
    """
    import itertools

    # Powers of two: 1..256, 1..4096 and 1..32 respectively.
    data_batch_sizes = np.logspace(0, 8, num=9, base=2, dtype=int)
    capacities = np.logspace(0, 12, num=13, base=2, dtype=int)
    nthreads = np.logspace(0, 5, num=6, base=2, dtype=int)

    results = []
    # Same iteration order as three nested loops: threads outermost,
    # then data batch size, then capacity.
    for nth, data_batch_size, capacity in itertools.product(
            nthreads, data_batch_sizes, capacities):
        # The queue's total capacity is scaled by the number of threads.
        cap = nth * capacity

        tf.reset_default_graph()
        d = DataHDF5(batch_size=data_batch_size)
        try:
            queue = data.Queue(d.node, d,
                               queue_type='fifo',
                               batch_size=BATCH_SIZE,
                               capacity=cap,
                               n_threads=nth)
            queue.kind = '{} / {} / {}'.format(nth, data_batch_size, capacity)
            # NOTE(review): time_tf presumably returns a DataFrame with at
            # least 'kind' and 'dur' columns (both are read below) — confirm.
            durs = time_tf(queue)
            durs['data batch size'] = data_batch_size
            durs['queue capacity'] = cap
            durs['nthreads'] = nth
            results.append(durs)
        finally:
            # Release the data source even if building or timing the queue
            # fails, so one bad configuration does not leak resources.
            d.cleanup()

    df = pandas.concat(results, ignore_index=True)
    # ``Series.astype('category', ordered=..., categories=...)`` was removed
    # from pandas; an explicit ordered CategoricalDtype keeps the 'kind'
    # labels ordered by first appearance (i.e. by run order).
    kind_dtype = pandas.api.types.CategoricalDtype(
        categories=df['kind'].unique(), ordered=True)
    df['kind'] = df['kind'].astype(kind_dtype)
    df.to_pickle(save_path)
    print(df.groupby(['nthreads', 'data batch size', 'queue capacity']).dur.mean())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号