tf_Session.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:LIE 作者: EmbraceLife 项目源码 文件源码
def _do_run(self, handle, target_list, fetch_list, feed_dict,
              options, run_metadata):
    """Executes one step for the supplied feeds, fetches and targets.

    The arguments are first translated into the representation used by the
    session backend (C-API handles when `self._created_with_new_api` is set,
    names otherwise) and then dispatched through `self._do_call` as either a
    regular run (`handle` is None) or a partial run (`handle` is given).

    Args:
      handle: a handle for partial_run. None if this is just a call to run().
      target_list: A list of operations to be run, but not fetched.
      fetch_list: A list of tensors to be fetched.
      feed_dict: A dictionary that maps tensors to numpy ndarrays.
      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None

    Returns:
      Whatever `self._do_call` returns for the selected callback: a list of
      numpy ndarrays, corresponding to the elements of `fetch_list`.  If the
      ith element of `fetch_list` contains the name of an operation, the
      first Tensor output of that operation will be returned for that
      element.

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
      RuntimeError: From the partial-run callback when `target_list` is
        non-empty.
    """
    if not self._created_with_new_api:
      # Legacy path: everything is addressed by name.
      feeds = {compat.as_bytes(t.name): v for t, v in feed_dict.items()}
      fetches = _name_list(fetch_list)
      targets = _name_list(target_list)
    else:
      # C-API path: tensors and operations are addressed by their C handles.
      # pylint: disable=protected-access
      feeds = {t._as_tf_output(): v for t, v in feed_dict.items()}
      fetches = [t._as_tf_output() for t in fetch_list]
      targets = [op._c_op for op in target_list]
      # pylint: enable=protected-access

    # NOTE: the parameters of the two callbacks below deliberately shadow the
    # outer arguments; they receive the converted values passed to `_do_call`.
    def _run_fn(session, feed_dict, fetch_list, target_list, options,
                run_metadata):
      # Push any pending graph changes to the runtime before running.
      self._extend_graph()
      with errors.raise_exception_on_not_ok_status() as status:
        if not self._created_with_new_api:
          return tf_session.TF_Run(
              session, options, feed_dict, fetch_list, target_list, status,
              run_metadata)
        return tf_session.TF_SessionRun_wrapper(
            session, options, feed_dict, fetch_list, target_list,
            run_metadata, status)

    def _prun_fn(session, handle, feed_dict, fetch_list):
      assert not self._created_with_new_api, (
          "Partial runs don't work with C API")
      # `target_list` here is the enclosing method's original argument.
      if target_list:
        raise RuntimeError('partial_run() requires empty target_list.')
      with errors.raise_exception_on_not_ok_status() as status:
        return tf_session.TF_PRun(
            session, handle, feed_dict, fetch_list, status)

    if handle is not None:
      # A handle means this step continues a partial run.
      return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
    return self._do_call(_run_fn, self._session, feeds, fetches, targets,
                         options, run_metadata)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号