api.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:transform 作者: tensorflow 项目源码 文件源码
def __init__(self, fn, args):
    """Runs `fn` on `args` and records the flattened inputs and outputs.

    Args:
      fn: Callable invoked as `fn(*args)`.
      args: Arguments passed to `fn`. Any `tf.SparseTensor` among them is
        recorded as its indices, values and dense_shape tensors.

    Raises:
      ValueError: If the number of entries in the analyzer collection differs
        before and after `fn` is called.
    """

    def _flatten(tensors):
      """Returns a flat list, expanding each `tf.SparseTensor` into 3 parts."""
      flat = []
      for item in tensors:
        if isinstance(item, tf.SparseTensor):
          flat.extend((item.indices, item.values, item.dense_shape))
        else:
          flat.append(item)
      return flat

    def _identity_copy(tensor):
      """Returns `tensor` passed through `tf.identity`, sparse or dense."""
      if not isinstance(tensor, tf.SparseTensor):
        return tf.identity(tensor)
      indices = tf.identity(tensor.indices)
      values = tf.identity(tensor.values)
      dense_shape = tf.identity(tensor.dense_shape)
      return tf.SparseTensor(indices, values, dense_shape)

    # Remember how large the table-initializer and analyzer collections are
    # before calling fn, so that afterwards we can (a) slice out the table
    # initializers fn added and (b) detect analyzers created inside fn.
    table_initializers = tf.get_collection_ref(
        tf.GraphKeys.TABLE_INITIALIZERS)
    analyzer_collection = tf.get_collection_ref(analyzers.ANALYZER_COLLECTION)
    num_initializers_before = len(table_initializers)
    num_analyzers_before = len(analyzer_collection)

    output = fn(*args)

    if len(analyzer_collection) != num_analyzers_before:
      raise ValueError(
          'One or more `Analyzer`s were created while inside '
          'FunctionApplication.__init__')

    # Only the initializers appended during the call belong to this op.
    self._table_initializers = table_initializers[num_initializers_before:]
    self._inputs = _flatten(args)

    # Every output goes through tf.identity so that outputs never coincide
    # with inputs; graph traversal has no clean way to handle that overlap.
    # `_user_output` is what `apply_function` hands back to the caller, while
    # `_outputs` is always a flat list of tensors.
    if isinstance(output, tuple):
      copies = []
      for tensor in output:
        copies.append(_identity_copy(tensor))
      self._user_output = copies
      to_flatten = copies
    else:
      self._user_output = _identity_copy(output)
      to_flatten = [self._user_output]
    self._outputs = _flatten(to_flatten)

    tf.add_to_collection(FUNCTION_APPLICATION_COLLECTION, self)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号