def required_feeds(cls, tensor):
    """Collect the feeds needed to evaluate ``tensor``, memoizing per object.

    ``tensor`` may be a list of graph nodes, a ``tf.Operation``, or any object
    exposing an ``.op`` attribute (presumably a ``tf.Tensor`` -- confirm with
    callers).

    If the object already has a ``required_feeds`` attribute, that value is
    returned unchanged. Otherwise the feeds of every input are obtained via
    ``cls.required_feeds``. For a list, the inputs are its elements; for an
    op, they are its data inputs followed by its control inputs. These feeds
    are folded together with ``merge``, starting from an empty
    ``networks.inputs.RequiredFeeds()``.

    The result is stored back on ``tensor`` as ``tensor.required_feeds``
    unless ``tensor`` is a list, since plain lists cannot hold extra
    attributes.

    NOTE(review): the recursion depth presumably grows with graph depth, so
    very deep graphs may hit Python's recursion limit.
    """
    # Fast path: reuse a previously computed (or externally attached) result.
    if hasattr(tensor, 'required_feeds'):
        return tensor.required_feeds

    is_list = isinstance(tensor, list)
    if is_list:
        parents = tensor
    else:
        if isinstance(tensor, tf.Operation):
            node = tensor
        else:
            node = tensor.op
        # Control inputs are included alongside the data inputs.
        parents = list(node.inputs)
        parents.extend(node.control_inputs)

    # Function-local import kept as in the original (possibly to avoid an
    # import cycle -- unverified).
    from networks import inputs
    accumulated = inputs.RequiredFeeds()
    for parent in parents:
        accumulated = accumulated.merge(cls.required_feeds(parent))

    # Memoize on the node itself; a plain list has no attribute storage.
    if not is_list:
        tensor.required_feeds = accumulated
    return accumulated
评论列表
文章目录