local_distributed_benchmark.py source file

python

Project: stuff  Author: yaroslavvb
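import tensorflow as tf  # TensorFlow 1.x graph-mode API

# NOTE: FLAGS (with an integer FLAGS.data_mb) is defined elsewhere in the
# original file; it is not shown in this snippet.


def create_graph(device1, device2):
  """Create a graph that keeps the `params` variable on device1 and the
  vector of ones plus the addition op on device2."""

  tf.reset_default_graph()
  dtype = tf.int32
  params_size = 250*1000*FLAGS.data_mb  # 1MB is 250k integers (4 bytes each)

  with tf.device(device1):
    params = tf.get_variable("params", [params_size], dtype,
                             initializer=tf.zeros_initializer)
  with tf.device(device2):
    # A constant node gets placed on device1 because of simple_placer,
    # so a variable is used to hold the vector of ones instead of:
    #    update = tf.constant(1, shape=[params_size], dtype=dtype)
    update = tf.get_variable("update", [params_size], dtype,
                             initializer=tf.ones_initializer)
    # Adding `update` (device2) into `params` (device1) moves
    # params_size integers between the two devices on every run.
    add_op = params.assign_add(update)

  # tf.initialize_all_variables() is deprecated in favor of this call.
  init_op = tf.global_variables_initializer()
  return init_op, add_op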
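
The sizing comment in the function ("1MB is 250k integers") can be checked with a few lines of plain Python. This is only a sketch: the helper names `params_size_for` and `size_in_bytes` are hypothetical and do not appear in the original file, and "MB" is taken to mean 10^6 bytes, which is what the 250*1000 factor implies for 4-byte integers.

```python
import struct

# Size of one 32-bit signed integer in bytes (matches tf.int32).
INT32_BYTES = struct.calcsize("<i")


def params_size_for(data_mb):
    """Number of int32 elements for data_mb megabytes (hypothetical helper).

    Mirrors the expression 250*1000*FLAGS.data_mb from create_graph.
    """
    return 250 * 1000 * data_mb


def size_in_bytes(num_elements):
    """Bytes occupied by num_elements int32 values (hypothetical helper)."""
    return num_elements * INT32_BYTES


if __name__ == "__main__":
    for data_mb in (1, 16, 128):
        n = params_size_for(data_mb)
        print("data_mb=%d -> %d integers, %d bytes" % (data_mb, n, size_in_bytes(n)))
```

Each run of `add_op` therefore involves a vector of roughly `data_mb` decimal megabytes, which is the quantity a transfer-rate measurement built on this graph would divide by elapsed time.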