def test_updater(dev='cpu'):
    """Check that pushes through a custom updater produce the expected sums.

    Installs the module-level ``updater`` on a fresh store from ``init_kv()``,
    then pushes arrays of ones from ``num_devs`` contexts of type ``dev``:

    * once to the single key ``3``, expecting ``num_devs`` on pull;
    * ``num_push`` times to every key in ``keys``, expecting
      ``num_devs * num_push`` on pull.

    NOTE(review): these expected values assume ``init_kv`` initializes the
    keys to zero and ``updater`` adds the received (device-summed) value into
    the stored one. Both are defined outside this view; confirm.

    Args:
        dev: device type passed to ``mx.Context`` (e.g. ``'cpu'``).
    """
    kv = init_kv()
    kv._set_updater(updater)

    # devices
    num_devs = 4
    devs = [mx.Context(dev, i) for i in range(num_devs)]

    # single key: one push of ones from each device
    vals = [mx.nd.ones(shape, d) for d in devs]
    kv.push(3, vals)
    kv.pull(3, out=vals)
    for v in vals:
        check_diff_to_scalar(v, num_devs)

    # list of keys: build a separate per-device list for every key.
    # ``[[...]] * len(keys)`` would alias one inner list across all keys, so
    # each key's pull would overwrite the same arrays and only the last key's
    # value would actually be checked below.
    vals = [[mx.nd.ones(shape, d) for d in devs] for _ in keys]
    num_push = 4
    for _ in range(num_push):
        kv.push(keys, vals)
    kv.pull(keys, out=vals)
    for vv in vals:
        for v in vv:
            check_diff_to_scalar(v, num_devs * num_push)
评论列表
文章目录