pipelines_common_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:magenta 作者: tensorflow 项目源码 文件源码
def testRandomPartition(self):
    random_partition = pipelines_common.RandomPartition(
        str, ['a', 'b', 'c'], [0.1, 0.4])
    random_nums = [0.55, 0.05, 0.34, 0.99]
    choices = ['c', 'a', 'b', 'c']
    random_partition.rand_func = functools.partial(six.next, iter(random_nums))
    self.assertEqual(random_partition.input_type, str)
    self.assertEqual(random_partition.output_type,
                     {'a': str, 'b': str, 'c': str})
    for i, s in enumerate(['hello', 'qwerty', '1234567890', 'zxcvbnm']):
      results = random_partition.transform(s)
      self.assertTrue(isinstance(results, dict))
      self.assertEqual(set(results.keys()), set(['a', 'b', 'c']))
      self.assertEqual(len(results.values()), 3)
      self.assertEqual(len([l for l in results.values() if l == []]), 2)  # pylint: disable=g-explicit-bool-comparison
      self.assertEqual(results[choices[i]], [s])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号