def get_valid_logits_and_labels(annotation_batch_tensor,
                                logits_batch_tensor,
                                class_labels):
    """Select the non-masked entries of a batch as aligned label/logit rows.

    The annotation batch is first passed through ``sample`` (a Python
    function defined elsewhere) by way of ``tf.py_func``, with its result
    declared as ``tf.int32``. The sampled annotations of size
    (batch_size, height, width) are then expanded into a label tensor of
    size (batch_size, height, width, num_classes). Only the valid entries
    are kept, giving a tensor of size (num_valid_entries, num_classes).
    The logits are gathered at the same positions, so both returned tensors
    have the same size. They can be fed directly to
    ``tf.nn.softmax_cross_entropy_with_logits()`` to obtain a per-entry
    cross-entropy.

    Parameters
    ----------
    annotation_batch_tensor : Tensor of size (batch_size, height, width)
        Class label for every pixel of every image in the batch.
    logits_batch_tensor : Tensor of size (batch_size, height, width, num_classes)
        Logits, e.g. the output of an FCN inference pass. Its leading
        dimensions must line up with the annotations, because it is
        indexed with positions computed from them.
    class_labels : list of ints
        The numbers that represent the classes. The last value in the list
        must be the number that was used for masking out.

    Returns
    -------
    (valid_labels_batch_tensor, valid_logits_batch_tensor) : tuple of Tensor
        Two tensors of size (num_valid_entries, num_classes) holding the
        labels and logits of the valid entries, in matching row order.

    Notes
    -----
    NOTE(review): what ``sample`` does to the annotations is not visible
    here. Confirm that it keeps the (batch_size, height, width) layout and
    returns int32 values. Tensors produced by ``tf.py_func`` carry no static
    shape information.
    """
    # Route the annotations through the Python-side ``sample`` helper first;
    # everything below works on the sampled annotations, not the input.
    sampled_annotations = tf.py_func(sample, [annotation_batch_tensor], tf.int32)

    # Label tensor with a trailing class axis, derived from the annotations.
    label_tensor = get_labels_from_annotation_batch(
        annotation_batch_tensor=sampled_annotations,
        class_labels=class_labels,
    )
    # Indices of the entries that are not masked out.
    valid_positions = get_valid_entries_indices_from_annotation_batch(
        annotation_batch_tensor=sampled_annotations,
        class_labels=class_labels,
    )

    # Pick the same positions out of both tensors so their rows stay aligned.
    picked_labels = tf.gather_nd(params=label_tensor, indices=valid_positions)
    picked_logits = tf.gather_nd(params=logits_batch_tensor, indices=valid_positions)
    return picked_labels, picked_logits
评论列表
文章目录