import tensorflow as tf

from tensorflow import gfile
from tensorflow import logging

# Note: the helpers referenced below (task_as_string, find_class_by_name,
# build_graph, the losses module, and FLAGS) are defined elsewhere in the
# surrounding training module.


def start_server(cluster, task):
  """Creates a Server.

  Args:
    cluster: A tf.train.ClusterSpec if the execution is distributed.
      None otherwise.
    task: A TaskSpec describing the job type and the task index.

  Returns:
    A tf.train.Server.
  """
  if not task.type:
    raise ValueError("%s: The task type must be specified." %
                     task_as_string(task))
  if task.index is None:
    raise ValueError("%s: The task index must be specified." %
                     task_as_string(task))

  # Create and start a server.
  return tf.train.Server(
      tf.train.ClusterSpec(cluster),
      protocol="grpc",
      job_name=task.type,
      task_index=task.index)
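
# Usage sketch (assumed, not part of the original listing): in a typical
# distributed run, the cluster description and the task come from the
# TF_CONFIG environment variable. TaskSpec below is a hypothetical stand-in
# for whatever object carries the `type` and `index` attributes used above.
import collections
import json
import os

TaskSpec = collections.namedtuple("TaskSpec", ["type", "index"])


def server_from_tf_config():
  """Builds start_server() arguments from TF_CONFIG (illustrative only)."""
  env = json.loads(os.environ.get("TF_CONFIG", "{}"))
  task_data = env.get("task", {"type": "master", "index": 0})
  task = TaskSpec(type=task_data["type"], index=task_data["index"])
  cluster_def = env.get("cluster")
  if not cluster_def:
    raise ValueError("TF_CONFIG does not describe a cluster.")
  return start_server(cluster_def, task)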


# The remaining functions take `self`: in the original module they are
# methods of the trainer class that owns `self.cluster` and `self.task`.
def start_server_if_distributed(self):
  """Starts a server if the execution is distributed."""
  if self.cluster:
    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    target = server.target
    device_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
        cluster=self.cluster)
  else:
    target = ""
    device_fn = ""
  return (target, device_fn)
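
# Caller sketch (hypothetical, not from the original listing): the returned
# pair is meant to be unpacked into a tf.device() scope and a session target.
# build_graph_fn is a placeholder for whatever constructs the training graph.
def run_trainer(trainer, build_graph_fn, train_dir):
  target, device_fn = trainer.start_server_if_distributed()
  with tf.Graph().as_default():
    with tf.device(device_fn):  # "" locally, a replica_device_setter on a cluster.
      build_graph_fn()
    sv = tf.train.Supervisor(logdir=train_dir)
    with sv.managed_session(target) as sess:  # target == "" means in-process.
      pass  # ...the training loop would run here...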


def get_meta_filename(self, start_new_model, train_dir):
  """Returns the meta-graph file to restore from, or None to build anew."""
  if start_new_model:
    logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                 task_as_string(self.task))
    return None

  latest_checkpoint = tf.train.latest_checkpoint(train_dir)
  if not latest_checkpoint:
    logging.info("%s: No checkpoint file found. Building a new model.",
                 task_as_string(self.task))
    return None

  meta_filename = latest_checkpoint + ".meta"
  if not gfile.Exists(meta_filename):
    logging.info("%s: No meta graph file found. Building a new model.",
                 task_as_string(self.task))
    return None
  else:
    return meta_filename
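
# Sketch of how the returned filename is typically consumed (assumed usage,
# not from this listing): a non-None meta filename is passed to
# tf.train.import_meta_graph() to rebuild the saved graph before restoring
# its weights.
def recover_model_sketch(meta_filename):
  logging.info("Restoring from meta graph file %s.", meta_filename)
  return tf.train.import_meta_graph(meta_filename)  # Returns a Saver.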


def build_model(self, model, reader):
  """Find the model and build the graph."""
  label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
  optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])

  build_graph(reader=reader,
              model=model,
              optimizer_class=optimizer_class,
              clip_gradient_norm=FLAGS.clip_gradient_norm,
              train_data_pattern=FLAGS.train_data_pattern,
              label_loss_fn=label_loss_fn,
              base_learning_rate=FLAGS.base_learning_rate,
              learning_rate_decay=FLAGS.learning_rate_decay,
              learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
              regularization_penalty=FLAGS.regularization_penalty,
              num_readers=FLAGS.num_readers,
              batch_size=FLAGS.batch_size,
              num_epochs=FLAGS.num_epochs)

  return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=0.25)
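
# Usage sketch (assumed, not from the listing). With max_to_keep=0 the Saver
# keeps every checkpoint it writes, so keep_checkpoint_every_n_hours only
# matters when max_to_keep is finite. The Saver is typically handed to a
# tf.train.Supervisor, which then checkpoints the model periodically.
def make_supervisor(trainer, model, reader, train_dir):
  saver = trainer.build_model(model, reader)
  return tf.train.Supervisor(logdir=train_dir,
                             saver=saver,
                             save_model_secs=15 * 60)  # Checkpoint every 15 min.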