使用tensorflow的Dataset管道,如何“命名” map操作的结果?

发布于 2021-01-29 16:22:14

我在下面有map函数(可运行示例),该函数输入astring并输出astring和an integer

tf.data.Dataset.from_tensor_slices我命名原始输入'filenames'。但是,当我从map函数返回值时,map_element_counts我只能返回一个元组(返回字典会生成异常)。

有没有一种方法可以命名从map_element_counts函数返回的2个元素?

import tensorflow as tf

filelist = ['fileA_6', 'fileB_10', 'fileC_7']

def map_element_counts(fname):
  # perform operations outside of tensorflow
  return 'test', 10

ds = tf.data.Dataset.from_tensor_slices({'filenames': filelist})
ds = ds.map(map_func=lambda x: tf.py_func(
  func=map_element_counts, inp=[x['filenames']], Tout=[tf.string, tf.int64]
))
element = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(element))

结果:

(b'test', 10)

所需结果:

{'elementA': b'test', 'elementB': 10)

添加的详细信息:

当我return {'elementA': 'test', 'elementB': 10}收到此异常时:

tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type dict
关注者
0
被浏览
48
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    申请tf.py_func内部ds.map作品。

    我创建了一个非常简单的文件作为示例。我只是在里面写10。

    dummy_file.txt:

    10
    

    这里是脚本:

    import tensorflow as tf
    
    filelist = ['dummy_file.txt', 'dummy_file.txt', 'dummy_file.txt']
    
    
    def py_func(input):
        # perform operations outside of tensorflow
        parsed_txt_file = int(input)
        return 'test', parsed_txt_file
    
    
    def map_element_counts(fname):
        # let tensorflow read the text file
        file_string = tf.read_file(fname['filenames'])
        # then use python function on the extracted string
        a, b = tf.py_func(
                        func=py_func, inp=[file_string], Tout=[tf.string, tf.int64]
                        )
        return {'elementA': a, 'elementB': b, 'file': fname['filenames']}
    
    ds = tf.data.Dataset.from_tensor_slices({'filenames': filelist})
    ds = ds.map(map_element_counts)
    element = ds.make_one_shot_iterator().get_next()
    
    with tf.Session() as sess:
        print(sess.run(element))
        print(sess.run(element))
        print(sess.run(element))
    

    输出:

    {'file': b'dummy_file.txt', 'elementA': b'test', 'elementB': 10}
    {'file': b'dummy_file.txt', 'elementA': b'test', 'elementB': 10}
    {'file': b'dummy_file.txt', 'elementA': b'test', 'elementB': 10}
    


知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看