model.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:cmcl 作者: chhwang 项目源码 文件源码
def variable_scheduler(var_list, pivot_list, gstep, name=None):
    """Schedule a variable according to the global step.

    Returns a scalar tensor whose value is ``var_list[i]`` while
    ``pivot_list[i] <= gstep < pivot_list[i + 1]``, and ``var_list[-1]``
    once ``gstep >= pivot_list[-1]``.

       e.g. var_list = [0.1, 0.01, 0.001], pivot_list = [0, 1000, 2000] then
         0    <= gstep < 1000 --> return 0.1
         1000 <= gstep < 2000 --> return 0.01
         2000 <= gstep        --> return 0.001

    Note: a ``gstep`` smaller than ``pivot_list[0]`` matches no interval and
    falls through to the default branch, i.e. ``var_list[-1]``.

    Args:
      var_list: List of values to return (each is wrapped in ``tf.constant``).
      pivot_list: List of integer pivots at which the value changes; assumed
        to be in increasing order (not checked here).
      gstep: Global step (# of batches trained so far); cast to int32.
      name(Optional): Name of the operation.

    Returns:
      A scalar tensor holding the scheduled value.

    Raises:
      ValueError: if ``var_list`` is empty or its length differs from the
        length of ``pivot_list``.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if len(var_list) != len(pivot_list):
        raise ValueError(
            'var_list and pivot_list must have the same length, got %d and %d'
            % (len(var_list), len(pivot_list)))
    if len(var_list) == 0:
        raise ValueError('var_list must not be empty')
    if len(var_list) == 1:
        return tf.constant(var_list[0], name=name)

    # tf.cast works in both TF1 and TF2 (tf.to_int32 was removed in TF2).
    gstep = tf.cast(gstep, tf.int32)

    # A list of (predicate, fn) pairs keeps the branch order deterministic and
    # avoids using Tensors as dict keys (they are unhashable in TF2).
    pred_fn_pairs = []
    for lo, hi, var in zip(pivot_list[:-1], pivot_list[1:], var_list[:-1]):
        in_range = tf.logical_and(
            tf.greater_equal(gstep, tf.constant(lo, tf.int32)),
            tf.less(gstep, tf.constant(hi, tf.int32)))
        # Bind `var` as a default argument so each branch returns its own
        # value instead of the loop variable's final value (late binding).
        pred_fn_pairs.append((in_range, lambda var=var: tf.constant(var)))

    default_var = var_list[-1]
    # The half-open intervals never overlap, so exclusive=True is safe.
    return tf.case(pred_fn_pairs,
                   default=lambda: tf.constant(default_var),
                   exclusive=True, name=name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号