def __call__(self, batch: Mapping[TensorPort, np.ndarray],
             goal_ports: List[TensorPort] = None) -> Mapping[TensorPort, np.ndarray]:
    """Evaluate the requested ports for one batch.

    The batch is turned into a feed dict with ``convert_to_feed_dict``.
    Every requested port that belongs to ``output_ports`` or
    ``training_output_ports`` is looked up in ``self.tensors`` and fetched
    through ``self.tf_session.run``. Requested ports that were not fetched
    but are present in ``batch`` are copied over from the batch unchanged.

    Args:
        batch: mapping from ports to values.
        goal_ports: ports to return; when omitted (or empty) the module's
            ``output_ports`` are used.

    Returns:
        A mapping from ports to values. A requested port that is neither
        fetched nor present in ``batch`` does not appear in the result.
    """
    requested = goal_ports if goal_ports else self.output_ports

    feeds = self.convert_to_feed_dict(batch)

    # Only ports this module produces can be fetched from the session.
    fetches = {}
    for port in requested:
        if port in self.output_ports or port in self.training_output_ports:
            fetches[port] = self.tensors[port]

    # NOTE(review): relies on the session returning a mutable mapping keyed
    # like ``fetches`` -- confirm against the session implementation in use.
    results = self.tf_session.run(fetches, feeds)

    # Pass through anything the caller asked for that the batch already has.
    for port in requested:
        if port in results:
            continue
        if port in batch:
            results[port] = batch[port]

    return results
# (Web-page navigation residue, not part of the code: "Comment list", "Table of contents".)