def test_aggregator():
    """Aggregate values pushed from multiple devices.

    Pushes one array per CPU context, first for the single key ``3`` and then
    for every key in ``keys``. It pulls the results back into those arrays and
    checks that each pulled array equals ``num_devs`` times the per-device
    value that was pushed.

    Relies on names defined elsewhere in this module: ``init_kv``, ``shape``,
    ``keys`` and ``check_diff_to_scalar``. ``check_diff_to_scalar`` is assumed
    to assert that every element of an array equals the given scalar
    (confirm).
    """
    kv = init_kv()

    # One CPU context per simulated device.
    num_devs = 4
    devs = [mx.Context('cpu', i) for i in range(num_devs)]

    # Single key: every device pushes ones, so the expected aggregate is num_devs.
    vals = [mx.nd.ones(shape, d) for d in devs]
    kv.push(3, vals)
    kv.pull(3, out=vals)
    for v in vals:
        check_diff_to_scalar(v, num_devs)

    # List of keys: build a *distinct* list of per-device arrays for each key.
    # ``[[...]] * len(keys)`` would repeat references to ONE inner list.
    # Every key's pull would then overwrite the same arrays, so a wrong
    # result for an earlier key could go unnoticed.
    vals = [[mx.nd.ones(shape, d) * 2.0 for d in devs] for _ in keys]
    kv.push(keys, vals)
    kv.pull(keys, out=vals)
    for vv in vals:
        for v in vv:
            check_diff_to_scalar(v, num_devs * 2.0)
评论列表
文章目录