functions.py source code

python

Project: neuralmonkey  Author: ufal
import tensorflow as tf


def piecewise_function(param, values, changepoints, name=None,
                       dtype=tf.float32):
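    """Compute a piecewise function.

    The result is `values[0]` while `param < changepoints[0]`,
    `values[i]` while `changepoints[i - 1] <= param < changepoints[i]`,
    and `values[-1]` once `param >= changepoints[-1]`.

    Arguments:
        param: The function parameter (a scalar tensor).
        values: List of function values (numbers or tensors).
        changepoints: List of points, sorted in ascending order, where the
            function changes from one value to the next. Must be one item
            shorter than `values`.
        name: Optional name scope for the created ops (defaults to
            "PiecewiseFunction").
        dtype: Data type to which the items of `values` are converted.

    Returns:
        A tensor holding the function value at `param`.
    """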

    if len(changepoints) != len(values) - 1:
        raise ValueError("changepoints has length {}, expected {} (values "
                         "has length {})".format(len(changepoints),
                                                 len(values) - 1,
                                                 len(values)))

    with tf.name_scope(name, "PiecewiseFunction",
                       [param, values, changepoints]) as s_name:
        values = [tf.convert_to_tensor(y, dtype=dtype) for y in values]
        # Bind `y` as a default argument so each lambda returns its own value;
        # a plain `lambda: y` is late-bound and every branch would return the
        # last value:
        lambdas = [lambda y=y: y for y in values]
        predicates = [tf.less(param, x) for x in changepoints]
        return tf.case(list(zip(predicates, lambdas[:-1])), lambdas[-1],
                       name=s_name)
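To make the branch-selection rule concrete, here is a pure-Python sketch that should return the same value as `piecewise_function` for plain numbers, without TensorFlow. It assumes `tf.case` is used with its default `exclusive=False`, so the first predicate that holds wins and the last value is the fallback. `piecewise_reference` is a hypothetical helper name, not part of neuralmonkey, and the schedule numbers are illustrative only.

```python
import bisect
from typing import Sequence, TypeVar

T = TypeVar("T")


def piecewise_reference(param: float, values: Sequence[T],
                        changepoints: Sequence[float]) -> T:
    """Hypothetical pure-Python counterpart of piecewise_function.

    Assumes `changepoints` is sorted in ascending order.
    """
    if len(changepoints) != len(values) - 1:
        raise ValueError("changepoints has length {}, expected {} (values "
                         "has length {})".format(len(changepoints),
                                                 len(values) - 1,
                                                 len(values)))
    # bisect_right returns how many changepoints are <= param, which is the
    # index of the first changepoint strictly greater than param, i.e. the
    # first predicate `param < x` that holds. If none holds, the index is
    # len(changepoints), which selects values[-1] (the default branch).
    return values[bisect.bisect_right(changepoints, param)]


# Illustrative step-decay schedule: 1.0 before step 1000, 0.1 until 2000,
# then 0.01.
rates = [1.0, 0.1, 0.01]
steps = [1000, 2000]
print([piecewise_reference(s, rates, steps) for s in (0, 999, 1000, 2500)])
# -> [1.0, 1.0, 0.1, 0.01]
```

A `param` exactly equal to a changepoint already takes the next value, because the comparison (`tf.less`) is strict.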
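The default-argument trick in the `lambdas` line matters because Python closures look up free variables when they are called, not when they are created. The plain-Python sketch below (independent of TensorFlow; the variable names are illustrative) shows the difference: without `y=y`, every branch function handed to `tf.case` would return the last element of `values`.

```python
values = [1, 2, 3]

# Late binding: each lambda reads `y` only when called, after the
# comprehension has finished, so all of them see the final value.
late = [lambda: y for y in values]

# Default-argument binding: `y=y` is evaluated when each lambda is created,
# so each one keeps its own value.
bound = [lambda y=y: y for y in values]

print([f() for f in late])   # [3, 3, 3]
print([f() for f in bound])  # [1, 2, 3]
```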