def train(self, fetches=None, feed_dict=None, use_lock=False):  # pylint: disable=arguments-differ
    """ Train the model with the data provided

    Parameters
    ----------
    fetches : tuple, list
        a sequence of `tf.Operation` and/or `tf.Tensor` to calculate
    feed_dict : dict
        input data, where key is a placeholder name and value is a numpy value
    use_lock : bool
        if True, the session run is guarded by ``self._train_lock`` so that
        concurrent calls from several threads are serialized. The lock is
        always released, even if running the session raises.

    Returns
    -------
    Calculated values of tensors in `fetches` in the same structure
    (as post-processed by ``self._fill_output``). The raw value passed to
    ``_fill_output`` is None when nothing was run or no user fetches exist.

    See also
    --------
    `Tensorflow Session run <https://www.tensorflow.org/api_docs/python/tf/Session#run>`_
    """
    with self.graph.as_default():
        _feed_dict = self._fill_feed_dict(feed_dict, is_training=True)
        if fetches is None:
            _fetches = tuple()
        else:
            _fetches = self._fill_fetches(fetches, default=None)

        if use_lock:
            self._train_lock.acquire()
        try:
            _all_fetches = []
            if self.train_step:
                _all_fetches.append(self.train_step)
            if _fetches is not None:
                _all_fetches.append(_fetches)

            output = None
            if _all_fetches:
                results = self.session.run(_all_fetches, feed_dict=_feed_dict)
                # User fetches, when requested, are always the last item; the
                # train_step result (if any) is discarded. Indexing instead of
                # unpacking exactly two values keeps this working when the
                # model has no train_step or there are no user fetches.
                if _fetches is not None:
                    output = results[-1]
        finally:
            # Release even when session.run raises; otherwise any later call
            # with use_lock=True would block forever on acquire().
            if use_lock:
                self._train_lock.release()

        return self._fill_output(output, _fetches)
# (scraped web-page residue, not code: "Comment list" / "Table of contents")