def inference(images, batch_size, train):
    """Build the OCR model: a convolutional front end followed by LSTM layers.

    Args:
        images: Images returned from distorted_inputs() or inputs().
        batch_size: Batch size, forwarded unchanged to both
            convolutional_layers() and get_lstm_layers().
        train: Flag forwarded unchanged to convolutional_layers();
            presumably selects training-mode behaviour there (TODO confirm).

    Returns:
        A ``(logits, timesteps)`` tuple. ``logits`` is whatever
        get_lstm_layers() returns. ``timesteps`` is the second value
        produced by convolutional_layers().
    """
    # convolutional_layers() yields a (features, timesteps) pair; both values
    # are handed straight to the recurrent stack.
    conv_features, num_steps = convolutional_layers(images, batch_size, train)
    # timesteps is also passed back to the caller, presumably for a
    # sequence-length-aware loss or decoder (NOTE(review): confirm).
    return get_lstm_layers(conv_features, num_steps, batch_size), num_steps
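# NOTE(review): orphaned fragment. It reads like the tail of a separate loss
# function rather than part of inference(); confirm where it belongs.
# The total loss is defined as the cross-entropy loss plus all of the
# weight-decay terms (L2 loss).
# return tf.add_n(tf.get_collection('losses'), name='total_loss')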
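# NOTE(review): web-page navigation residue from the source this was copied
# from, not code: "评论列表" ("Comment list"), "文章目录" ("Table of contents").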