test_memory_preformance.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:deep_rl_vizdoom 作者: mihahauke 项目源码 文件源码
def test_memory(insertions, samples, img_shape, misc_len, batch_size, capacity, img_dtype=np.float32):
    """Benchmark insertion and sampling throughput of ``ReplayMemory``.

    The memory is first prefilled with ``capacity`` copies of one fixed
    transition, so the timed loops run against a full buffer. Then
    ``insertions`` calls to ``add_transition`` and ``samples`` calls to
    ``get_sample`` are timed separately. The resulting rates are printed.

    Args:
        insertions: number of timed ``add_transition`` calls.
        samples: number of timed ``get_sample`` calls.
        img_shape: shape of the image part of each state.
        misc_len: length of the misc-vector part of each state.
        batch_size: batch size passed to ``ReplayMemory``.
        capacity: capacity passed to ``ReplayMemory``; also the number of
            untimed prefill insertions.
        img_dtype: dtype of the image array. Accepts a numpy scalar type,
            an ``np.dtype``, or a dtype string. Any dtype other than float32
            gets image values scaled to [0, 255).

    Returns:
        None. Results are printed to stdout.
    """
    # perf_counter is the monotonic, high-resolution clock intended for
    # measuring short intervals; time() is wall-clock and can be coarse.
    from time import perf_counter

    dtype = np.dtype(img_dtype)  # also accept dtype objects / strings
    print("image shape:", img_shape)
    print("misc vector length:", misc_len)
    print("batchsize:", batch_size)
    print("capacity:", capacity)
    print("image data type:", dtype.name)
    memory = ReplayMemory(img_shape, misc_len, capacity, batch_size)

    s = _random_state(img_shape, misc_len, dtype)
    s2 = _random_state(img_shape, misc_len, dtype)
    a = 0
    r = 1.0
    terminal = False
    # Untimed prefill so the timed loops below measure a full buffer.
    for _ in trange(capacity, leave=False, desc="Prefilling memory."):
        memory.add_transition(s, a, s2, r, terminal)

    start = perf_counter()
    for _ in trange(insertions, leave=False, desc="Testing insertions speed"):
        memory.add_transition(s, a, s2, r, terminal)
    inserts_time = perf_counter() - start

    start = perf_counter()
    for _ in trange(samples, leave=False, desc="Testing sampling speed"):
        memory.get_sample()
    sample_time = perf_counter() - start

    _print_rate("insertions", insertions, inserts_time)
    _print_rate("samples", samples, sample_time)
    print()


def _random_state(img_shape, misc_len, dtype):
    """Return a random ``[image, misc]`` state; non-float32 images are scaled to [0, 255)."""
    img = np.random.random(img_shape)
    if dtype != np.float32:
        # Presumably meant for integer image types (e.g. uint8): without the
        # scaling, casting values in [0, 1) would truncate everything to 0.
        img = img * 255
    return [img.astype(dtype), np.random.random(misc_len).astype(np.float32)]


def _print_rate(label, count, elapsed):
    """Print ops/s and seconds-per-1k for ``count`` ops that took ``elapsed`` seconds."""
    if count <= 0 or elapsed <= 0:
        # Avoid ZeroDivisionError when nothing was timed or the clock did not advance.
        print("\tno {} timed ({} ops in {:0.4f}s)".format(label, count, elapsed))
        return
    print("\t{:0.1f} {}/s. 1k {} in: {:0.2f}s".format(count / elapsed, label, label,
                                                         elapsed / count * 1000))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号