def create_graph(device0, device1, data_mb=None):
    """Create graph that keeps var1 on device0, var2 on device1 and adds them.

    Resets the default graph, then builds two int32 vectors filled with ones,
    each pinned to its own device, plus an op that adds var2 into var1 in
    place.

    Args:
        device0: Device specification passed to ``tf.device`` for ``var1``.
        device1: Device specification passed to ``tf.device`` for ``var2``.
        data_mb: Size of each variable in (decimal) megabytes. When omitted
            (``None``), ``FLAGS.data_mb`` is used, matching the original
            behavior.

    Returns:
        A ``(init_op, add_op)`` tuple. ``init_op`` initializes all global
        variables; ``add_op`` performs ``var1 += var2`` via ``assign_add``.
    """
    # Side effect: discards whatever graph was previously the default.
    tf.reset_default_graph()
    dtype = tf.int32
    if data_mb is None:
        data_mb = FLAGS.data_mb
    # int32 elements are 4 bytes, so 250k of them is ~1MB (10^6 bytes).
    params_size = 250 * 1000 * data_mb
    with tf.device(device0):
        var1 = tf.get_variable("var1", [params_size], dtype,
                               initializer=tf.ones_initializer())
    with tf.device(device1):
        var2 = tf.get_variable("var2", [params_size], dtype,
                               initializer=tf.ones_initializer())
    # Created outside both device scopes: reads var2 and updates var1, so when
    # device0 != device1 this presumably forces a cross-device transfer.
    # NOTE(review): confirm actual op placement if it matters for measurement.
    add_op = var1.assign_add(var2)
    init_op = tf.global_variables_initializer()
    return init_op, add_op
评论列表
文章目录