ops.py source file

python

Project: shuttleNet  Author: shiyemin
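import tensorflow as tf  # TensorFlow 1.x API (tf.train.*, graph mode)
# Assumed import: the snippet uses `ops.name_scope` without showing where `ops` comes from.
from tensorflow.python.framework import ops


def adjust_max(start, stop, start_value, stop_value, name=None):
    """Hold start_value until step `start`, ramp linearly to stop_value at step `stop`.

    Returns a float32 tensor, or None if the graph has no global step.
    """
    with ops.name_scope(name, "AdjustMax",
                        [start, stop, start_value, stop_value]) as name:
        global_step = tf.train.get_global_step()
        if global_step is not None:
            # Steps are compared against the int64 global step; values are float32.
            start = tf.convert_to_tensor(start, dtype=tf.int64)
            stop = tf.convert_to_tensor(stop, dtype=tf.int64)
            start_value = tf.convert_to_tensor(start_value, dtype=tf.float32)
            stop_value = tf.convert_to_tensor(stop_value, dtype=tf.float32)

            # Mutually exclusive branches: before the ramp, and inside the ramp.
            pred_fn_pairs = [
                (global_step <= start, lambda: start_value),
                # power=1.0 makes polynomial_decay a plain linear interpolation.
                ((global_step > start) & (global_step <= stop),
                 lambda: tf.train.polynomial_decay(
                     start_value, global_step - start, stop - start,
                     end_learning_rate=stop_value, power=1.0, cycle=False)),
            ]
            # After `stop`, the value stays at stop_value.
            default = lambda: stop_value
            return tf.case(pred_fn_pairs, default, exclusive=True)
        else:
            return None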
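
To see what values the schedule above produces without building a TensorFlow graph, here is a minimal pure-Python sketch of the same piecewise rule. The name `adjust_max_value` and the explicit `step` argument are hypothetical and are not part of shuttleNet. The sketch assumes `tf.train.polynomial_decay` with `power=1.0` and `cycle=False` reduces to linear interpolation between the two values.

```python
def adjust_max_value(step, start, stop, start_value, stop_value):
    """Pure-Python illustration of the adjust_max schedule (hypothetical helper)."""
    if step <= start:
        # Before the ramp begins: hold the starting value.
        return float(start_value)
    if step <= stop:
        # Inside the ramp: linear interpolation, matching polynomial decay with power 1.
        fraction = (step - start) / (stop - start)
        return (start_value - stop_value) * (1.0 - fraction) + stop_value
    # After the ramp: hold the final value.
    return float(stop_value)


if __name__ == "__main__":
    for step in (0, 100, 150, 200, 300):
        print(step, adjust_max_value(step, 100, 200, 1.0, 0.0))
```

Although the function is called `adjust_max`, the rule works in either direction: passing a `stop_value` larger than `start_value` yields a linear warm-up instead of a decay.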