def build_model(self):
    """Resolve the configured components, build the graph, and make a Saver.

    The input reader is picked from the ``frame_features`` / ``frame_only``
    flags. The model, label loss and optimizer are looked up by name with
    ``find_class_by_name``; the remaining hyper-parameters are read from
    ``FLAGS`` and handed to ``build_graph``.

    Returns:
        A ``tf.train.Saver`` created with ``max_to_keep=2`` and
        ``keep_checkpoint_every_n_hours=0.25``.
    """
    # Convert the feature_names / feature_sizes flags to lists of values.
    names, sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    # Choose the reader class first, then construct it once below.
    if not FLAGS.frame_features:
        reader_cls = readers.YT8MAggregatedFeatureReader
    elif FLAGS.frame_only:
        reader_cls = readers.YT8MFrameFeatureOnlyReader
    else:
        reader_cls = readers.YT8MFrameFeatureReader
    input_reader = reader_cls(feature_names=names, feature_sizes=sizes)

    # The model and the loss are instantiated here; the optimizer is passed
    # on as a class.
    model_cls = find_class_by_name(FLAGS.model, [labels_embedding])
    net = model_cls()
    loss_cls = find_class_by_name(FLAGS.label_loss, [losses_embedding])
    loss_fn = loss_cls()
    opt_cls = find_class_by_name(FLAGS.optimizer, [tf.train])

    graph_kwargs = {
        "reader": input_reader,
        "model": net,
        "optimizer_class": opt_cls,
        "clip_gradient_norm": FLAGS.clip_gradient_norm,
        "train_data_pattern": FLAGS.train_data_pattern,
        "label_loss_fn": loss_fn,
        "base_learning_rate": FLAGS.base_learning_rate,
        "learning_rate_decay": FLAGS.learning_rate_decay,
        "learning_rate_decay_examples": FLAGS.learning_rate_decay_examples,
        "regularization_penalty": FLAGS.regularization_penalty,
        "num_readers": FLAGS.num_readers,
        "batch_size": FLAGS.batch_size,
        "num_epochs": FLAGS.num_epochs,
    }
    build_graph(**graph_kwargs)

    task_label = task_as_string(self.task)
    logging.info("%s: Built graph.", task_label)

    saver = tf.train.Saver(max_to_keep=2, keep_checkpoint_every_n_hours=0.25)
    return saver
评论列表
文章目录