def __call__(self, batch: Mapping[TensorPort, np.ndarray],
             goal_ports: List[TensorPort] = None) -> Mapping[TensorPort, np.ndarray]:
    """Run a batch through the prediction module and return values for the goal ports.

    Args:
        batch: mapping from ports to values. The value for each port in
            ``self.input_ports`` is converted with ``create_torch_variable``; a port
            absent from ``batch`` is passed to that conversion as ``None``.
        goal_ports: optional ports whose values should be returned. When ``None``
            (or any other falsy value, e.g. an empty list), ``self.output_ports`` is used.

    Returns:
        A mapping from goal ports to values:
        - goal ports that are module outputs map to the module's output converted
          with ``torch_to_numpy``;
        - goal ports the module does not produce but that are present in ``batch``
          are copied from ``batch`` unchanged;
        - goal ports found in neither place are omitted.
    """
    goal_ports = goal_ports or self.output_ports
    # Set for O(1) membership checks; ports are hashable since they are used as dict keys.
    wanted = set(goal_ports)
    # Loop-invariant: query the CUDA device count once instead of once per input port.
    use_gpu = torch.cuda.device_count() > 0
    inputs = [p.create_torch_variable(batch.get(p), gpu=use_gpu) for p in self.input_ports]
    # Call the module instance rather than .forward() so that registered hooks are run.
    outputs = self.prediction_module(*inputs)
    # Outputs are paired positionally with self.output_ports.
    # NOTE(review): assumes the module returns one output per output port, in the same
    # order — confirm against the prediction module definition.
    ret = {p: p.torch_to_numpy(t) for p, t in zip(self.output_ports, outputs) if p in wanted}
    # For goal ports the module did not produce, fall back to values supplied in the batch.
    for p in goal_ports:
        if p not in ret and p in batch:
            ret[p] = batch[p]
    return ret
评论列表
文章目录