torch.py source code

python

Project: jack  Author: uclmr
def __call__(self, batch: Mapping[TensorPort, np.ndarray],
                 goal_ports: List[TensorPort] = None) -> Mapping[TensorPort, np.ndarray]:
        """Runs a batch and returns values/outputs for specified goal ports.
        Args:
            batch: mapping from ports to values
            goal_ports: optional list of output ports to return; defaults to this module's output_ports

        Returns:
            A mapping from goal ports to tensors.
        """
        goal_ports = goal_ports or self.output_ports
        inputs = [p.create_torch_variable(batch.get(p), gpu=torch.cuda.device_count() > 0) for p in self.input_ports]
        outputs = self.prediction_module.forward(*inputs)
        ret = {p: p.torch_to_numpy(t) for p, t in zip(self.output_ports, outputs) if p in goal_ports}
        for p in goal_ports:
            if p not in ret and p in batch:
                ret[p] = batch[p]
        return ret
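
The port-selection logic in the method above can be tried without torch. The following is a minimal sketch, assuming plain strings stand in for `TensorPort` objects and plain lists stand in for tensors; `run_ports` and `toy_forward` are hypothetical names introduced only for this illustration and are not part of jack.

```python
from typing import Callable, Dict, List, Optional


def run_ports(batch: Dict[str, list],
              input_ports: List[str],
              output_ports: List[str],
              forward: Callable[..., tuple],
              goal_ports: Optional[List[str]] = None) -> Dict[str, list]:
    """Hypothetical stand-in for __call__: strings play the role of ports."""
    # Fall back to all output ports when no goal ports are requested.
    goal_ports = goal_ports or output_ports
    # Missing inputs become None, mirroring batch.get(p) in the original.
    inputs = [batch.get(p) for p in input_ports]
    outputs = forward(*inputs)
    # Keep only the outputs that were requested.
    ret = {p: t for p, t in zip(output_ports, outputs) if p in goal_ports}
    # Requested ports that the model did not produce are copied from the batch.
    for p in goal_ports:
        if p not in ret and p in batch:
            ret[p] = batch[p]
    return ret


def toy_forward(x: list, y: list) -> tuple:
    """Hypothetical model: elementwise sum and product of two lists."""
    return ([a + b for a, b in zip(x, y)],
            [a * b for a, b in zip(x, y)])


batch = {"x": [1.0, 2.0], "y": [3.0, 4.0]}
result = run_ports(batch, ["x", "y"], ["sum", "product"], toy_forward,
                   goal_ports=["sum", "x"])
print(sorted(result))  # ['sum', 'x']
```

Here `"product"` is computed but dropped because it was not requested, while `"x"` is not a model output and is therefore passed through unchanged from the input batch.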