tensorflow.py source code

python

Project: jack  Author: uclmr
def __call__(self, batch: Mapping[TensorPort, np.ndarray],
                 goal_ports: List[TensorPort] = None) -> Mapping[TensorPort, np.ndarray]:
        """Runs a batch and returns values/outputs for specified goal ports.
        Args:
            batch: mapping from ports to values
            goal_ports: optional list of ports to return; defaults to this module's output_ports

        Returns:
            A mapping from goal ports to tensors.
        """
        goal_ports = goal_ports or self.output_ports
        feed_dict = self.convert_to_feed_dict(batch)
        goal_tensors = {p: self.tensors[p] for p in goal_ports
                        if p in self.output_ports or p in self.training_output_ports}
        outputs = self.tf_session.run(goal_tensors, feed_dict)

        for p in goal_ports:
            if p not in outputs and p in batch:
                outputs[p] = batch[p]

        return outputs
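
The fallback behaviour of the method above (default to the module's output ports, evaluate what the module can compute, then pass through requested ports that are already in the batch) can be tried without TensorFlow. The following is only a minimal sketch: `ToyModule`, its `computations` dictionary, and the string port names are hypothetical stand-ins for the real module, `self.tensors`, and `TensorPort`, and a plain dictionary lookup replaces the `tf_session.run` call and the output-port filter.

```python
from typing import Callable, Dict, List, Mapping, Optional


class ToyModule:
    """Hypothetical stand-in for the module above; no TensorFlow involved."""

    def __init__(self,
                 computations: Dict[str, Callable[[Mapping[str, list]], list]],
                 output_ports: List[str]):
        # Plays the role of self.tensors: port name -> something that can be evaluated.
        self.computations = computations
        self.output_ports = output_ports

    def __call__(self, batch: Mapping[str, list],
                 goal_ports: Optional[List[str]] = None) -> Dict[str, list]:
        # No goal ports given: fall back to the module's own output ports.
        goal_ports = goal_ports or self.output_ports
        # Evaluate only the ports this module can compute
        # (the original does this step with a session run).
        outputs = {p: self.computations[p](batch)
                   for p in goal_ports if p in self.computations}
        # Requested ports that were not computed but exist in the batch
        # are passed through unchanged.
        for p in goal_ports:
            if p not in outputs and p in batch:
                outputs[p] = batch[p]
        return outputs


module = ToyModule({"doubled": lambda b: [v * 2 for v in b["x"]]},
                   output_ports=["doubled"])
batch = {"x": [1, 2, 3]}

default_result = module(batch)
mixed_result = module(batch, goal_ports=["doubled", "x", "missing"])
```

Here `default_result` holds only the computed `doubled` port, while `mixed_result` also carries `x` straight from the batch. The `missing` port is neither computable nor present in the batch, so it is absent from the result. The original method treats such a port the same way, so callers should not assume every requested port comes back.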