test_normalization.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:BlueWhale 作者: caffe2 项目源码 文件源码
def test_preprocessing_network(self):
        """Check the Caffe2 preprocessing net against the reference preprocessing.

        Feature data comes from ``preprocessing_util.read_data()``, and
        normalization parameters are derived from it with
        ``normalization.identify_parameters``. Expected outputs come from
        ``self.preprocess``. A net named ``"PreprocessingTestNet"`` is built by
        calling ``PreprocessorNet.preprocess_blob`` once per feature, then run.
        Each ``<feature>_preprocessed`` blob it produces must match the expected
        values elementwise within ``rtol=atol=0.01``. The ``'boxcox'`` feature
        uses a looser ``0.1``.
        """
        default_tolerance = 0.01
        # At the limit, boxcox has some numerical instability, so it gets a
        # looser tolerance than the other features.
        tolerance_overrides = {'boxcox': 0.1}

        feature_value_map = preprocessing_util.read_data()
        normalization_parameters = normalization.identify_parameters(
            feature_value_map
        )
        test_features = self.preprocess(
            feature_value_map, normalization_parameters
        )

        net = core.Net("PreprocessingTestNet")
        preprocessor = PreprocessorNet(net, False)
        for feature_name in feature_value_map:
            # Feed a placeholder blob before the net is built. The real values
            # are fed only after CreateNet below.
            # NOTE(review): presumably CreateNet needs the input blobs to
            # exist; confirm whether the int32 dtype matters.
            workspace.FeedBlob(feature_name, np.array([0], dtype=np.int32))
            preprocessor.preprocess_blob(
                feature_name, normalization_parameters[feature_name]
            )

        workspace.CreateNet(net)

        for feature_name in feature_value_map:
            workspace.FeedBlob(feature_name, feature_value_map[feature_name])
        workspace.RunNetOnce(net)

        for feature_name in feature_value_map:
            normalized_features = workspace.FetchBlob(
                feature_name + "_preprocessed"
            )
            expected = test_features[feature_name]
            tolerance = tolerance_overrides.get(feature_name, default_tolerance)
            # Compute the elementwise comparison once. It drives both the
            # assertion and the report of which values differ.
            matches = np.isclose(
                normalized_features, expected, rtol=tolerance, atol=tolerance
            )
            non_matching = np.where(np.logical_not(matches))
            self.assertTrue(
                np.all(matches),
                '{} does not match: {} {}'.format(
                    feature_name,
                    normalized_features[non_matching].tolist(),
                    expected[non_matching].tolist(),
                ),
            )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号