def create_graph(device1, device2):
    """Build a two-device graph: a variable on one device, an adder on the other.

    The default graph is reset first. A zero-initialized int32 variable
    ``params`` is created under ``device1``; a one-initialized int32 variable
    ``update`` of the same length, together with the op that adds it into
    ``params``, is created under ``device2``.

    Args:
        device1: Device specification passed to ``tf.device`` for ``params``.
        device2: Device specification passed to ``tf.device`` for ``update``
            and the assign-add op.

    Returns:
        A tuple ``(init_op, add_op)``: the op that initializes all variables
        and the op that performs ``params += update``.
    """
    tf.reset_default_graph()

    int_dtype = tf.int32
    # 1MB is 250k integers, so scale by the requested number of megabytes.
    num_elements = 250 * 1000 * FLAGS.data_mb

    with tf.device(device1):
        params = tf.get_variable(
            "params",
            shape=[num_elements],
            dtype=int_dtype,
            initializer=tf.zeros_initializer,
        )

    with tf.device(device2):
        # A variable is used for the vector of ones instead of a tf.constant:
        # per the original author's note, a constant node gets placed on
        # device1 because of simple_placer.
        update = tf.get_variable(
            "update",
            shape=[num_elements],
            dtype=int_dtype,
            initializer=tf.ones_initializer,
        )
        add_op = params.assign_add(update)

    init_op = tf.initialize_all_variables()
    return init_op, add_op
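# Web-page residue (not part of the program): "Comment list"
# Web-page residue (not part of the program): "Article table of contents"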