def __init__(self, fn, args):
  def _decompose_tensors(tensor_list):
    result = []
    for tensor in tensor_list:
      if isinstance(tensor, tf.SparseTensor):
        result.append(tensor.indices)
        result.append(tensor.values)
        result.append(tensor.dense_shape)
      else:
        result.append(tensor)
    return result
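
  # For example (hypothetical tensors), _decompose_tensors([dense, sparse])
  # returns [dense, sparse.indices, sparse.values, sparse.dense_shape]: a
  # SparseTensor contributes its three component tensors, while a dense
  # Tensor passes through unchanged.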

  def _copy_tensor(tensor):
    if isinstance(tensor, tf.SparseTensor):
      return tf.SparseTensor(
          tf.identity(tensor.indices),
          tf.identity(tensor.values),
          tf.identity(tensor.dense_shape))
    else:
      return tf.identity(tensor)

  # Apply fn to its args, keeping track of any table initializers that are
  # added while fn is running, and also checking that no analyzers are added
  # while fn is running.
  all_table_initializers = tf.get_collection_ref(
      tf.GraphKeys.TABLE_INITIALIZERS)
  all_analyzers = tf.get_collection_ref(analyzers.ANALYZER_COLLECTION)
  original_num_table_initializers = len(all_table_initializers)
  original_num_analyzers = len(all_analyzers)
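  # Note: tf.get_collection_ref returns the collection's live list rather
  # than a copy, so any initializers fn appends while running become visible
  # in all_table_initializers at indices past the lengths recorded above.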
  output = fn(*args)
  if len(all_analyzers) != original_num_analyzers:
    raise ValueError(
        'One or more `Analyzer`s were created while inside '
        'FunctionApplication.__init__')

  # Set inputs and outputs of this op, flattening inputs and outputs into a
  # list of tensors, but storing outputs in the original format for the
  # return value of `apply_function`.
  self._table_initializers = all_table_initializers[
      original_num_table_initializers:]
  self._inputs = _decompose_tensors(args)
  # When traversing the graph, there isn't a clean way to handle `Map`s whose
  # inputs and outputs overlap. Therefore we apply tf.identity to all outputs
  # to ensure the outputs and inputs don't overlap.
  if isinstance(output, tuple):
    self._user_output = [_copy_tensor(tensor) for tensor in output]
    self._outputs = _decompose_tensors(self._user_output)
  else:
    self._user_output = _copy_tensor(output)
    self._outputs = _decompose_tensors([self._user_output])

  tf.add_to_collection(FUNCTION_APPLICATION_COLLECTION, self)
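
# A minimal usage sketch of the machinery above, assuming the TF 1.x-era
# tf.Transform API in which `apply_function` (referenced in the comments
# above) wraps a user fn so that any table initializers it creates are
# captured. The vocabulary path and feature keys here are hypothetical.
def lookup_fn(strings):
  # Creating the table inside fn appends its initializer to
  # tf.GraphKeys.TABLE_INITIALIZERS while fn runs; FunctionApplication
  # then slices it out into self._table_initializers.
  table = tf.contrib.lookup.index_table_from_file(
      vocabulary_file='vocab.txt',  # hypothetical vocabulary file
      num_oov_buckets=1)
  return table.lookup(strings)

def preprocessing_fn(inputs):
  return {'token_ids': apply_function(lookup_fn, inputs['tokens'])}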