impl_helper_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:transform 作者: tensorflow 项目源码 文件源码
def testMakeOutputDict(self):
    """Tests impl_helper.make_output_dict over every supported feature kind.

    The schema covers fixed-length features of several shapes (`a`-`d`), a
    VarLenFeature (`e`) and a SparseFeature (`f`). The assertions expect:
      * the dense fetches to come back with the same contents;
      * `e` to be indexable per batch row, each row holding that row's
        values (row 1 has none);
      * `f` to have length 2, where `f[0]` holds per-row index lists and
        `f[1]` holds per-row value lists (row 2 has none).
    """
    schema = self.toSchema({
        'a': tf.FixedLenFeature(None, tf.int64),
        'b': tf.FixedLenFeature([], tf.float32),
        'c': tf.FixedLenFeature([1], tf.float32),
        'd': tf.FixedLenFeature([2, 2], tf.float32),
        'e': tf.VarLenFeature(tf.string),
        'f': tf.SparseFeature('idx', 'val', tf.float32, 10)
    })

    # A batch of three rows for each fixed-length feature. These lists are
    # both the fetched values (as numpy arrays) and the expected output.
    dense_rows = {
        'a': [100, 200, 300],
        'b': [10.0, 20.0, 30.0],
        'c': [[40.0], [80.0], [120.0]],
        'd': [[[1.0, 2.0], [3.0, 4.0]],
              [[5.0, 6.0], [7.0, 8.0]],
              [[9.0, 10.0], [11.0, 12.0]]],
    }
    fetches = {key: np.array(rows) for key, rows in dense_rows.items()}
    fetches['e'] = tf.SparseTensorValue(
        indices=np.array([(0, 0), (0, 1), (0, 2), (2, 0), (2, 1), (2, 2)]),
        values=np.array(['doe', 'a', 'deer', 'a', 'female', 'deer']),
        dense_shape=(3, 3))
    # NOTE(review): the width 20 in dense_shape is larger than the
    # SparseFeature size of 10 declared in the schema. The assertions below
    # do not depend on it, but confirm the mismatch is intentional.
    fetches['f'] = tf.SparseTensorValue(
        indices=np.array([(0, 2), (0, 4), (0, 8), (1, 8), (1, 4)]),
        values=np.array([10.0, 20.0, 30.0, 40.0, 50.0]),
        dense_shape=(3, 20))

    output_dict = impl_helper.make_output_dict(schema, fetches)
    self.assertSetEqual(set(six.iterkeys(output_dict)),
                        {'a', 'b', 'c', 'd', 'e', 'f'})

    for key in ('a', 'b', 'c', 'd'):
        self.assertAllEqual(output_dict[key], dense_rows[key])

    # Per-row values of the var-len feature `e`; the middle row is empty.
    expected_e = (['doe', 'a', 'deer'], [], ['a', 'female', 'deer'])
    for row, values in enumerate(expected_e):
        self.assertAllEqual(output_dict['e'][row], values)

    sparse_f = output_dict['f']
    self.assertEqual(len(sparse_f), 2)
    expected_f_indices = ([2, 4, 8], [8, 4], [])
    expected_f_values = ([10.0, 20.0, 30.0], [40.0, 50.0], [])
    for row, indices in enumerate(expected_f_indices):
        self.assertAllEqual(sparse_f[0][row], indices)
    for row, values in enumerate(expected_f_values):
        self.assertAllEqual(sparse_f[1][row], values)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号