train_embedding.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def build_model(self):
    """Assemble the training graph and return a checkpoint saver.

    Parses the feature name/size flags, picks a reader according to
    ``FLAGS.frame_features`` / ``FLAGS.frame_only``, resolves the model,
    label loss and optimizer classes by name, and passes everything to
    ``build_graph``.

    Returns:
      A ``tf.train.Saver`` created with ``max_to_keep=2`` and
      ``keep_checkpoint_every_n_hours=0.25``.
    """
    # Turn the feature name/size flags into parallel lists.
    names, sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    # Choose the reader class first, then instantiate it once below.
    # FLAGS.frame_only is only consulted when frame features are enabled.
    if not FLAGS.frame_features:
        reader_cls = readers.YT8MAggregatedFeatureReader
    elif FLAGS.frame_only:
        reader_cls = readers.YT8MFrameFeatureOnlyReader
    else:
        reader_cls = readers.YT8MFrameFeatureReader
    reader = reader_cls(feature_names=names, feature_sizes=sizes)

    # Resolve the model, loss and optimizer by their flag-supplied names.
    # The model and loss are instantiated; the optimizer stays a class.
    model_cls = find_class_by_name(FLAGS.model, [labels_embedding])
    model = model_cls()
    loss_cls = find_class_by_name(FLAGS.label_loss, [losses_embedding])
    label_loss_fn = loss_cls()
    optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])

    # Collect the graph-construction arguments, then build the graph.
    graph_kwargs = dict(
        reader=reader,
        model=model,
        optimizer_class=optimizer_class,
        clip_gradient_norm=FLAGS.clip_gradient_norm,
        train_data_pattern=FLAGS.train_data_pattern,
        label_loss_fn=label_loss_fn,
        base_learning_rate=FLAGS.base_learning_rate,
        learning_rate_decay=FLAGS.learning_rate_decay,
        learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
        regularization_penalty=FLAGS.regularization_penalty,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs,
    )
    build_graph(**graph_kwargs)

    logging.info("%s: Built graph.", task_as_string(self.task))

    saver = tf.train.Saver(max_to_keep=2, keep_checkpoint_every_n_hours=0.25)
    return saver
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号