def testRandomPartition(self):
  """Checks that RandomPartition routes each input to exactly one partition.

  rand_func is replaced with a deterministic source so the partition picked
  for each input is known in advance: the i-th call to transform is expected
  to consume the i-th value of `random_nums` and place its input in the
  partition named by `choices[i]`, leaving every other partition empty.
  """
  random_partition = pipelines_common.RandomPartition(
      str, ['a', 'b', 'c'], [0.1, 0.4])
  # Fixed draws used in place of real randomness; each call to rand_func
  # returns the next value from this list.
  random_nums = [0.55, 0.05, 0.34, 0.99]
  # Partition expected for each draw above. Presumably 'a' for draws below
  # 0.1, 'b' below 0.1 + 0.4, otherwise 'c' -- TODO confirm against
  # RandomPartition's implementation.
  choices = ['c', 'a', 'b', 'c']
  # Builtin next() is what six.next aliases and exists on Python 2.6+ and 3.
  random_partition.rand_func = functools.partial(next, iter(random_nums))

  self.assertEqual(random_partition.input_type, str)
  self.assertEqual(random_partition.output_type,
                   {'a': str, 'b': str, 'c': str})

  inputs = ['hello', 'qwerty', '1234567890', 'zxcvbnm']
  for s, choice in zip(inputs, choices):
    results = random_partition.transform(s)
    self.assertIsInstance(results, dict)
    # All three partitions must be present; only the expected one holds the
    # input, the other two are empty lists. Comparing whole dicts gives a
    # readable diff on failure.
    expected = {name: [] for name in ('a', 'b', 'c')}
    expected[choice] = [s]
    self.assertEqual(results, expected)
评论列表
文章目录