def _dnn_optimizer(args):
  # Adam optimizer shared by the DNN regression/classification models.
  return tf.train.AdamOptimizer(args.learning_rate, epsilon=args.epsilon)


def _linear_optimizer(args):
  # FTRL optimizer shared by the linear regression/classification models.
  return tf.train.FtrlOptimizer(
      args.learning_rate,
      l1_regularization_strength=args.l1_regularization,
      l2_regularization_strength=args.l2_regularization)


def get_estimator(args, output_dir, features, stats, target_vocab_size):
  """Build the tf.contrib.learn estimator requested by args.model.

  Args:
    args: parsed command-line flags. Reads args.model,
        args.hidden_layer_sizes, args.learning_rate, args.epsilon,
        args.l1_regularization, args.l2_regularization, and
        args.min_eval_frequency.
    output_dir: base output directory; checkpoints are written under
        <output_dir>/train.
    features: feature definitions forwarded to build_feature_columns.
    stats: dataset statistics forwarded to build_feature_columns.
    target_vocab_size: number of target classes (classification models only).

  Returns:
    A tf.contrib.learn estimator (DNN/Linear x Regressor/Classifier).

  Raises:
    ValueError: if hidden-layer flags are inconsistent with the model type,
        or args.model is not one of the four supported values.
  """
  # DNN models require hidden layers; linear models must not have them.
  if is_dnn_model(args.model) and not args.hidden_layer_sizes:
    raise ValueError('--hidden-layer-size* must be used with DNN models')
  if is_linear_model(args.model) and args.hidden_layer_sizes:
    raise ValueError('--hidden-layer-size* cannot be used with linear models')

  # Build tf.learn feature columns from the dataset features/stats.
  feature_columns = build_feature_columns(features, stats, args.model)

  # Checkpoint every min_eval_frequency steps so evaluation sees fresh
  # checkpoints at that cadence.
  config = tf.contrib.learn.RunConfig(
      save_checkpoints_steps=args.min_eval_frequency)
  train_dir = os.path.join(output_dir, 'train')

  if args.model == 'dnn_regression':
    estimator = tf.contrib.learn.DNNRegressor(
        feature_columns=feature_columns,
        hidden_units=args.hidden_layer_sizes,
        config=config,
        model_dir=train_dir,
        optimizer=_dnn_optimizer(args))
  elif args.model == 'linear_regression':
    estimator = tf.contrib.learn.LinearRegressor(
        feature_columns=feature_columns,
        config=config,
        model_dir=train_dir,
        optimizer=_linear_optimizer(args))
  elif args.model == 'dnn_classification':
    estimator = tf.contrib.learn.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=args.hidden_layer_sizes,
        n_classes=target_vocab_size,
        config=config,
        model_dir=train_dir,
        optimizer=_dnn_optimizer(args))
  elif args.model == 'linear_classification':
    estimator = tf.contrib.learn.LinearClassifier(
        feature_columns=feature_columns,
        n_classes=target_vocab_size,
        config=config,
        model_dir=train_dir,
        optimizer=_linear_optimizer(args))
  else:
    raise ValueError('bad --model-type value')
  return estimator