ocr.py source file

python
Project: tf-cnn-lstm-ocr-captcha  Author: Luonic
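def train_simple(total_loss, global_step):
  """Build the training op: Adam with a staircase exponentially decayed learning rate.

  Returns the op that applies one optimization step and the learning-rate tensor.
  """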
  with tf.variable_scope('train_op'):
    # Variables that affect learning rate.
    num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                    global_step,
                                    decay_steps,
                                    LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)
    tf.summary.scalar('learning_rate', lr)

    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
        # opt = tf.train.MomentumOptimizer(lr, 0.9).minimize(total_loss, global_step=global_step)
    opt = tf.train.AdamOptimizer(lr).minimize(total_loss, global_step=global_step)

    # Summary names may not contain spaces or parentheses; use a plain suffix.
    tf.summary.scalar(total_loss.op.name + '_raw', total_loss)
  return opt, lr
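
The learning-rate schedule built in train_simple can be checked by hand. With staircase=True, tf.train.exponential_decay computes initial_lr * decay_rate ** (global_step // decay_steps), so the rate drops in discrete jumps once every decay_steps steps. The sketch below reproduces that formula in plain Python. The function name staircase_decay and every numeric value are hypothetical and are not taken from ocr.py.

```python
def staircase_decay(initial_lr, global_step, decay_steps, decay_rate):
    """Plain-Python equivalent of exponential decay with staircase=True."""
    # Integer division keeps the exponent constant within each decay interval,
    # so the learning rate changes only at multiples of decay_steps.
    return initial_lr * decay_rate ** (global_step // decay_steps)


# Hypothetical settings, chosen only to make the arithmetic easy to follow.
num_examples_per_epoch = 6400
batch_size = 64
num_epochs_per_decay = 1.0

# Same derivation as in train_simple: batches per epoch times epochs per decay.
num_batches_per_epoch = num_examples_per_epoch / batch_size
decay_steps = int(num_batches_per_epoch * num_epochs_per_decay)

for step in (0, 99, 100, 250):
    print(step, staircase_decay(0.1, step, decay_steps, 0.5))
```

With these assumed values the rate stays at its initial value through step 99, halves at step 100, and halves again at step 200.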