def test_preprocessing_network(self):
    """Check PreprocessorNet output against the reference ``self.preprocess``.

    Reads the test feature data, derives normalization parameters from it,
    builds a Caffe2 net with one preprocessing sub-graph per feature, runs
    the net on the real feature values and asserts that every
    ``<feature>_preprocessed`` blob is element-wise close (``np.isclose``
    with equal rtol/atol) to the values ``self.preprocess`` produced for the
    same feature.
    """
    feature_value_map = preprocessing_util.read_data()
    normalization_parameters = normalization.identify_parameters(
        feature_value_map
    )
    # Expected outputs, computed outside the Caffe2 net.
    test_features = self.preprocess(
        feature_value_map, normalization_parameters
    )

    net = core.Net("PreprocessingTestNet")
    preprocessor = PreprocessorNet(net, False)
    for feature_name in feature_value_map:
        # Feed a placeholder so the input blob exists while the preprocessing
        # ops are added; the real values are fed after CreateNet below.
        # NOTE(review): presumably required by preprocess_blob — confirm.
        workspace.FeedBlob(feature_name, np.array([0], dtype=np.int32))
        preprocessor.preprocess_blob(
            feature_name, normalization_parameters[feature_name]
        )
    workspace.CreateNet(net)

    for feature_name in feature_value_map:
        workspace.FeedBlob(feature_name, feature_value_map[feature_name])
    workspace.RunNetOnce(net)

    for feature_name in feature_value_map:
        normalized_features = workspace.FetchBlob(
            feature_name + "_preprocessed"
        )
        expected_features = test_features[feature_name]
        # At the limit, boxcox has some numerical instability, so it gets a
        # looser tolerance than every other feature.
        tolerance = 0.1 if feature_name == 'boxcox' else 0.01
        # Compute the closeness mask once; it drives both the assertion and
        # the indices reported in the failure message.
        matches = np.isclose(
            normalized_features,
            expected_features,
            rtol=tolerance,
            atol=tolerance,
        )
        non_matching = np.where(np.logical_not(matches))
        self.assertTrue(
            np.all(matches),
            '{} does not match: {} {}'.format(
                feature_name,
                normalized_features[non_matching].tolist(),
                expected_features[non_matching].tolist(),
            ),
        )
评论列表
文章目录