task.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pydatalab 作者: googledatalab 项目源码 文件源码
def get_estimator(args, output_dir, features, stats, target_vocab_size):
  """Build the tf.contrib.learn estimator selected by ``args.model``.

  Args:
    args: parsed command-line arguments. Fields read here are ``model``,
      ``hidden_layer_sizes``, ``min_eval_frequency`` and ``learning_rate``.
      The model type decides which others are read: ``epsilon`` for DNN
      models, ``l1_regularization`` / ``l2_regularization`` for linear ones.
    output_dir: directory that will contain the ``train`` model directory.
    features: forwarded unchanged to ``build_feature_columns``.
    stats: forwarded unchanged to ``build_feature_columns``.
    target_vocab_size: passed as ``n_classes``; used only by the classifiers.

  Returns:
    One of DNNRegressor, LinearRegressor, DNNClassifier or LinearClassifier.

  Raises:
    ValueError: if a DNN model has no hidden layer sizes, if a linear model is
      given hidden layer sizes, or if ``args.model`` is not one of
      'dnn_regression', 'linear_regression', 'dnn_classification' or
      'linear_classification'.
  """
  model_name = args.model

  # DNN models need hidden layers; linear models must not be given any.
  if is_dnn_model(model_name):
    if not args.hidden_layer_sizes:
      raise ValueError('--hidden-layer-size* must be used with DNN models')
  if is_linear_model(model_name):
    if args.hidden_layer_sizes:
      raise ValueError('--hidden-layer-size* cannot be used with linear models')

  columns = build_feature_columns(features, stats, model_name)

  # Write a checkpoint every `min_eval_frequency` training steps.
  run_config = tf.contrib.learn.RunConfig(
      save_checkpoints_steps=args.min_eval_frequency)
  model_dir = os.path.join(output_dir, 'train')

  def make_adam():
    # Optimizer used by both DNN estimators.
    return tf.train.AdamOptimizer(args.learning_rate, epsilon=args.epsilon)

  def make_ftrl():
    # Optimizer used by both linear estimators, with L1/L2 regularization.
    return tf.train.FtrlOptimizer(
        args.learning_rate,
        l1_regularization_strength=args.l1_regularization,
        l2_regularization_strength=args.l2_regularization)

  # Keyword arguments shared by every estimator type.
  shared = {
      'feature_columns': columns,
      'config': run_config,
      'model_dir': model_dir,
  }

  if model_name == 'dnn_regression':
    return tf.contrib.learn.DNNRegressor(
        hidden_units=args.hidden_layer_sizes,
        optimizer=make_adam(),
        **shared)

  if model_name == 'linear_regression':
    return tf.contrib.learn.LinearRegressor(optimizer=make_ftrl(), **shared)

  if model_name == 'dnn_classification':
    return tf.contrib.learn.DNNClassifier(
        hidden_units=args.hidden_layer_sizes,
        n_classes=target_vocab_size,
        optimizer=make_adam(),
        **shared)

  if model_name == 'linear_classification':
    return tf.contrib.learn.LinearClassifier(
        n_classes=target_vocab_size,
        optimizer=make_ftrl(),
        **shared)

  raise ValueError('bad --model-type value')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号