将tf.data.Dataset包装到tf.function中是否可以提高性能?

发布于 2021-01-29 16:20:58

给定以下两个示例,对亲笔签名是否有性能改进tf.data.Dataset

数据集不在tf.function中

import tensorflow as tf


class MyModel(tf.keras.Model):

    def call(self, inputs):
        return tf.ones([1, 1]) * inputs


model = MyModel()
model2 = MyModel()


@tf.function
def train_step(data):
    output = model(data)
    output = model2(output)
    return output


dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))

for data in dataset:
    train_step(data)

tf.function中的数据集

import tensorflow as tf


class MyModel(tf.keras.Model):

    def call(self, inputs):
        return tf.ones([1, 1]) * inputs


model = MyModel()
model2 = MyModel()


@tf.function
def train():
    dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))
    def train_step(data):
        output = model(data)
        output = model2(output)
        return output
    for data in dataset:
        train_step(data)


train()
关注者
0
被浏览
256
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    添加@tf.function确实可以显着提高速度。看看这个:

    import tensorflow as tf
    
    data = tf.random.normal((1000, 10, 10, 1))
    dataset = tf.data.Dataset.from_tensors(data).batch(10)
    
    def iterate_1(dataset):
        for x in dataset:
            x = x
    
    @tf.function
    def iterate_2(dataset):
        for x in dataset:
            x = x
    
    %timeit -n 1000 iterate_1(dataset) # 1.46 ms ± 8.2 µs per loop
    %timeit -n 1000 iterate_2(dataset) # 239 µs ± 10.2 µs per loop
    

    如您所见,使用进行迭代的@tf.function速度提高了6倍以上。



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看