def __init__(self, model, loss_fn, data_path, log_dir, graph, input_reader):
    """Initialize a `TrainEvalBase` object.

    Stores the caller-supplied training/evaluation collaborators on the
    instance and derives separate "train" and "eval" subdirectories of
    `log_dir`. No directories are created here; only paths are computed.

    Args:
      model: an instance of a subclass of the `ModelBase` class (defined in
        `model_base.py`).
      loss_fn: the TensorFlow loss function used to train `model`. Available
        loss functions are listed in the (TF contrib) losses module:
        https://www.tensorflow.org/code/tensorflow/contrib/losses/python/losses/loss_ops.py
      data_path: a string, path to files of tf.Example protos containing data.
      log_dir: a string, root logging directory.
      graph: a TensorFlow computation graph.
      input_reader: an instance of a subclass of the `InputReaderBase` class
        (defined in `input_reader_base.py`).
    """
    # Collaborators handed in by the caller, kept as-is.
    self._model = model
    self._loss_fn = loss_fn
    self._graph = graph
    self._input_reader = input_reader
    self._data_path = data_path

    # Logging locations: the root directory plus one subdirectory per phase.
    self._log_dir = log_dir
    self._train_log_dir, self._eval_log_dir = (
        os.path.join(log_dir, phase) for phase in ("train", "eval")
    )

    # Starts empty; presumably filled with summary ops by other methods of
    # the class — NOTE(review): confirm against the rest of the class.
    self._summary_ops = []
评论列表
文章目录