def testPS(self):
    """Check device placement for one clone with one parameter-server task.

    Builds a ``DeploymentConfig`` with ``num_clones=1`` and
    ``num_ps_tasks=1``, verifies the devices and scope it reports for the
    clone, the optimizer and the inputs, and then verifies the devices of
    ops and variables created under ``variables_device()``.
    """
    worker_cpu = '/job:worker/device:CPU:0'
    ps_cpu = '/job:ps/task:0/device:CPU:0'

    config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)

    # Devices and scope reported directly by the config.
    self.assertDeviceEqual(config.clone_device(0), '/job:worker')
    self.assertEqual(config.clone_scope(0), '')
    self.assertDeviceEqual(config.optimizer_device(), worker_cpu)
    self.assertDeviceEqual(config.inputs_device(), worker_cpu)

    # Creation order is kept as-is: two plain variables, a no-op, then a
    # slim variable that is given the config's caching device.
    with tf.device(config.variables_device()):
        first_var = tf.Variable(0)
        second_var = tf.Variable(0)
        plain_op = tf.no_op()
        cached_var = slim.variable(
            'a', [], caching_device=config.caching_device())

    # Plain variables are expected on the PS task, and their value reads
    # are expected to report the same device as the variable itself.
    for var in (first_var, second_var):
        self.assertDeviceEqual(var.device, ps_cpu)
        self.assertDeviceEqual(var.device, var.value().device)

    # The no-op is expected to carry no device at all.
    self.assertDeviceEqual(plain_op.device, '')

    # The slim variable is expected on the PS task, while its value read is
    # expected to report an empty device.
    self.assertDeviceEqual(cached_var.device, ps_cpu)
    self.assertDeviceEqual(cached_var.value().device, '')
model_deploy_test.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录