def get(self, images, num_classes, train_phase=False, l2_penalty=0.0):
    """Build the model graph on top of the given inputs.

    Call this both when defining the model for training and when exporting
    the model in the protobuf format.

    Args:
        images: model input
        num_classes: number of classes to predict
        train_phase: set it to True when defining the model, during train
        l2_penalty: float value, weight decay (l2) penalty

    Returns:
        is_training_: scalar placeholder named "is_training_" that defaults
            to False; used to enable/disable training ops at run time
        logits: the model output, as produced by `self._inference`
    """
    # Defaults to False, so the flag only needs to be fed explicitly when
    # the training-time behaviour is wanted.
    training_flag = tf.placeholder_with_default(
        False, shape=(), name="is_training_")

    # Delegate the actual graph construction (images -> logits).
    model_output = self._inference(
        images, num_classes, training_flag, train_phase, l2_penalty)

    return training_flag, model_output
# Web page residue (not code): "Comment list"
# Web page residue (not code): "Article table of contents"