train_eval_base.py source code

python

Project: easy-tensorflow  Author: khanhptnk
import os


class TrainEvalBase(object):

    def __init__(self, model, loss_fn, data_path, log_dir, graph, input_reader):
        """Initialize a `TrainEvalBase` object.

        Args:
          model: an instance of a subclass of the `ModelBase` class (defined in
            `model_base.py`).
          loss_fn: a TensorFlow loss function used to train the model. See
            https://www.tensorflow.org/code/tensorflow/contrib/losses/python/losses/loss_ops.py
            for a list of available loss functions.
          data_path: a string, the path to files of tf.Example protos containing
            the data.
          log_dir: a string, the logging directory. Training and evaluation logs
            go to its "train" and "eval" subdirectories.
          graph: a TensorFlow computation graph.
          input_reader: an instance of a subclass of the `InputReaderBase` class
            (defined in `input_reader_base.py`).
        """
        self._data_path = data_path
        self._log_dir = log_dir
        # Separate subdirectories for training and evaluation logs.
        self._train_log_dir = os.path.join(self._log_dir, "train")
        self._eval_log_dir = os.path.join(self._log_dir, "eval")

        self._model = model
        self._loss_fn = loss_fn
        self._graph = graph
        self._input_reader = input_reader

        # Summary ops start out empty.
        self._summary_ops = []
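The constructor derives two log directories from `log_dir`. The sketch below isolates that step in a hypothetical helper, `derive_log_dirs`, which is not part of easy-tensorflow. It only illustrates the directory layout the constructor produces.

```python
import os


def derive_log_dirs(log_dir):
    """Hypothetical helper mirroring how TrainEvalBase splits log_dir.

    Returns (train_log_dir, eval_log_dir) as sibling subdirectories.
    """
    train_log_dir = os.path.join(log_dir, "train")
    eval_log_dir = os.path.join(log_dir, "eval")
    return train_log_dir, eval_log_dir


if __name__ == "__main__":
    train_dir, eval_dir = derive_log_dirs("/tmp/run1")
    print(train_dir)  # on POSIX: /tmp/run1/train
    print(eval_dir)   # on POSIX: /tmp/run1/eval
```

Keeping training and evaluation logs in sibling subdirectories of one `log_dir` means TensorBoard, when pointed at `log_dir`, treats each subdirectory containing event files as a separate run. The training and evaluation curves can then be compared side by side.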