def __init__(self, config, input_queue=None, predict_tower=None):
    """
    Set up a trainer that feeds the model through an input queue.

    :param config: a `TrainConfig` instance
    :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
        Defaults to a FIFO queue of capacity 50 whose component dtypes
        match the model's input vars.
    :param predict_tower: list of gpu relative idx to run prediction. Defaults to [0]
        (also used when an empty list is given). Use -1 for cpu.
    """
    super(QueueInputTrainer, self).__init__(config)
    # NOTE(review): `self.model` is presumably set by the base trainer from
    # `config` -- confirm against the superclass.
    self.input_vars = self.model.get_input_vars()
    if input_queue is None:
        # Deliberately a smaller capacity (50) than one might pick by default,
        # to avoid https://github.com/tensorflow/tensorflow/issues/2942
        self.input_queue = tf.FIFOQueue(
            50, [x.dtype for x in self.input_vars], name='input_queue')
    else:
        self.input_queue = input_queue
    # By default (None or an empty list), use the first training gpu for prediction.
    self.predict_tower = predict_tower or [0]
    # Placeholder for the tensors dequeued from `input_queue`; presumably
    # populated later by the graph-building code -- TODO confirm.
    self.dequed_inputs = None
评论列表
文章目录