def build_augmentation_graph(self):
    """Build the queue-fed data augmentation section of the TensorFlow graph.

    The image component of ``self.input_variables["labeled_crops"]`` is passed
    through ``imaging.augmentations.aument_multiple``. It is then enqueued,
    together with the remaining (label) components, into a
    ``tf.RandomShuffleQueue``, from which a single element is dequeued.

    Side effects:
        Sets ``self.train_variables`` to a dict containing:
            "image_batch": the dequeued augmented image tensor (note: a
                           single ``dequeue()`` result, not ``dequeue_many``),
            "queue":       the RandomShuffleQueue,
            "enqueue_op":  the ``enqueue_many`` op that feeds the queue,
            "target_<i>":  ``tf.one_hot`` of the i-th dequeued label, with
                           depth ``self.dset.targets[i].shape[1]``.

    Returns:
        None.
    """
    n_targets = len(self.dset.targets)
    queue_cfg = self.config["queueing"]

    # Component layout: one float32 image followed by one int32 label per
    # target.
    component_dtypes = [tf.float32] + [tf.int32] * n_targets
    shuffle_queue = tf.RandomShuffleQueue(
        queue_cfg["random_queue_size"],  # capacity
        queue_cfg["min_size"],  # min_after_dequeue
        component_dtypes,
        # NOTE(review): presumably one shape per component (image + each
        # label); confirm where input_variables["shapes"] is built.
        shapes=self.input_variables["shapes"]["crops"],
    )

    labeled_crops = self.input_variables["labeled_crops"]
    # Only the image component (index 0) is augmented; the label components
    # are enqueued unchanged. Slicing labeled_crops and concatenating it to a
    # list assumes labeled_crops is a list. TODO(review): confirm.
    # NOTE(review): "aument_multiple" is spelled as referenced here; check it
    # against imaging.augmentations before renaming.
    augmented_images = imaging.augmentations.aument_multiple(
        labeled_crops[0],
        queue_cfg["augmentation_workers"],
    )
    enqueue_op = shuffle_queue.enqueue_many(
        [augmented_images] + labeled_crops[1:]
    )

    # Single-element dequeue. TODO(review): a dequeue_many over
    # config["training"]["minibatch"] was previously commented out here;
    # confirm that batching is handled downstream.
    dequeued = shuffle_queue.dequeue()

    self.train_variables = {
        "image_batch": dequeued[0],
        "queue": shuffle_queue,
        "enqueue_op": enqueue_op,
    }

    # Dequeued components 1..n_targets are int32 labels. Each becomes a
    # one-hot tensor whose depth comes from its target's shape[1].
    for idx in range(n_targets):
        depth = self.dset.targets[idx].shape[1]
        self.train_variables["target_{}".format(idx)] = tf.one_hot(
            dequeued[idx + 1], depth
        )
#################################################
## START TRAINING QUEUES
#################################################
评论列表
文章目录