Source code of variables.py

python

Project: Tensormodels  Author: asheshjain399
def global_step(device=''):
  """Returns the global step variable.

  Args:
    device: Optional device to place the variable on. It can be a string or a
      function that is called to get the device for the variable.

  Returns:
    The tensor representing the global step variable.
  """
  global_step_ref = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
  if global_step_ref:
    return global_step_ref[0]
  else:
    collections = [
        VARIABLES_TO_RESTORE,
        tf.GraphKeys.VARIABLES,
        tf.GraphKeys.GLOBAL_STEP,
    ]
    # Get the device for the variable.
    with tf.device(variable_device(device, 'global_step')):
      return tf.get_variable('global_step', shape=[], dtype=tf.int64,
                             initializer=tf.zeros_initializer,
                             trainable=False, collections=collections)
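
The function above follows a get-or-create pattern: it looks up a named collection first and only creates (and registers) the variable when nothing is found, so repeated calls return the same object. The following is a minimal plain-Python sketch of that pattern, not the library's implementation. `ToyGraph`, `GLOBAL_STEP`, and `get_or_create_global_step` are hypothetical names introduced for illustration, and a dict stands in for the TensorFlow variable.

```python
from collections import defaultdict

# Hypothetical stand-in for tf.GraphKeys.GLOBAL_STEP.
GLOBAL_STEP = 'global_step'


class ToyGraph:
    """Hypothetical stand-in for a graph that keeps named collections."""

    def __init__(self):
        self.collections = defaultdict(list)

    def get_collection(self, key):
        # Return a copy so callers cannot mutate the registry by accident.
        return list(self.collections[key])

    def add_to_collections(self, keys, value):
        # Register the same object under every requested collection key.
        for key in keys:
            self.collections[key].append(value)


def get_or_create_global_step(graph):
    """Returns the registered step counter, creating it on first use."""
    existing = graph.get_collection(GLOBAL_STEP)
    if existing:
        return existing[0]
    # Nothing registered yet: create the counter and register it.
    step = {'name': 'global_step', 'value': 0}
    graph.add_to_collections(['variables', GLOBAL_STEP], step)
    return step


graph = ToyGraph()
first = get_or_create_global_step(graph)
second = get_or_create_global_step(graph)
```

Because the second call finds the entry registered by the first, both names refer to the same object and the collection holds exactly one entry.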