def create_global_step(graph=None):
    """Create the global step variable in a graph.

    Args:
      graph: Graph in which the global step is created. When omitted (None),
        the current default graph is used.

    Returns:
      The newly created global step tensor.

    Raises:
      ValueError: If the graph already has a global step.
    """
    if graph is None:
        graph = ops.get_default_graph()

    existing_step = get_global_step(graph)
    if existing_step is not None:
        raise ValueError('"global_step" already exists.')

    # The step is registered both as a regular variable and under the
    # dedicated GLOBAL_STEP key so later lookups can find it.
    # NOTE(review): newer TensorFlow releases renamed GraphKeys.VARIABLES to
    # GraphKeys.GLOBAL_VARIABLES -- confirm against the TF version in use.
    step_collections = [ops.GraphKeys.VARIABLES, ops.GraphKeys.GLOBAL_STEP]

    # Build inside the target graph and reset to the root name scope, so the
    # variable is not nested under whatever scope the caller is currently in.
    with graph.as_default() as g, g.name_scope(None):
        return variable(
            ops.GraphKeys.GLOBAL_STEP,
            shape=[],
            dtype=dtypes.int64,
            initializer=init_ops.zeros_initializer,
            trainable=False,
            collections=step_collections,
        )
评论列表
文章目录