test_data_provider.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:deep_metric_learning 作者: ronekko 项目源码 文件源码
def check_generate_valid_indexes(self, num_examples, batch_size):
        T = 90
        scheme = EpochwiseShuffledInfiniteScheme(num_examples, batch_size)
        uniquenesses = []
        all_indexes = []
        for i in range(T):
            indexes = next(scheme)
            is_unique = len(indexes) == len(np.unique(indexes))
            uniquenesses.append(is_unique)
            all_indexes.append(indexes)

        assert np.all(uniquenesses)

        counts = np.bincount(np.concatenate(all_indexes).ravel())
        expected_counts = [batch_size * T // num_examples] * num_examples
        assert np.array_equal(counts, expected_counts)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号