def scale_features(features, train):
    """Scale features, fitting the scaler on the training rows only.

    The scaler is selected by ``FLAGS.scaling``:

    * ``'max_abs'`` uses ``preprocessing.MaxAbsScaler``.
    * ``'standard'`` uses ``preprocessing.StandardScaler``.

    The scaler's parameters are learned from ``features[train]`` only. They
    are then applied to every row of ``features``, so rows not selected by
    ``train`` do not influence the learned scaling.

    Args:
        features: Feature matrix to scale. It must support indexing with
            ``train``. Presumably a NumPy array or SciPy sparse matrix;
            confirm against callers.
        train: Index array or boolean mask selecting the training rows that
            are used to fit the scaler.

    Returns:
        ``features`` itself (the same object, not a copy) when
        ``FLAGS.scaling`` is None. Otherwise, the output of
        ``scaler.transform(features)``: all rows scaled with the
        training-fitted parameters.

    Raises:
        ValueError: If ``FLAGS.scaling`` is set to an unrecognized value.
    """
    if FLAGS.scaling is None:
        # No scaling requested: return the input object untouched.
        return features
    logging.info('Scaling features with %s', FLAGS.scaling)
    if FLAGS.scaling == 'max_abs':
        scaler = preprocessing.MaxAbsScaler()
    elif FLAGS.scaling == 'standard':
        scaler = preprocessing.StandardScaler()
    else:
        raise ValueError('Unrecognized scaling %s' % FLAGS.scaling)
    # Fit only on the training rows so statistics from held-out rows do not
    # leak into the scaling parameters; then transform every row.
    scaler.fit(features[train])
    return scaler.transform(features)
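# NOTE(review): web-page residue, not code. The original text was
# "评论列表" ("Comment list") and "文章目录" ("Article table of contents").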