def piecewise_function(param, values, changepoints, name=None,
                       dtype=tf.float32):
    """Compute a piecewise-constant function of `param`.

    The result is `values[0]` when `param < changepoints[0]`, `values[i]`
    when `changepoints[i-1] <= param < changepoints[i]`, and `values[-1]`
    when `param >= changepoints[-1]`.

    Args:
      param: The function parameter. Presumably a scalar (tensor or number),
        since `tf.case` requires scalar boolean predicates -- confirm at
        call sites.
      values: List of function values (numbers or tensors). Each entry is
        converted to a tensor of type `dtype`.
      changepoints: Sorted (ascending) list of points where the function
        changes from one value to the next. Must be exactly one item
        shorter than `values`. Sortedness is NOT checked here.
      name: Optional name for the operation; defaults to
        "PiecewiseFunction".
      dtype: Type every entry of `values` is converted to.

    Returns:
      A tensor holding the value selected for `param`.

    Raises:
      ValueError: If `values` is empty, or if `changepoints` does not have
        exactly `len(values) - 1` entries.
    """
    # Use len() rather than truthiness so numpy arrays are accepted too.
    if len(values) == 0:
        raise ValueError("values must contain at least one element")
    if len(changepoints) != len(values) - 1:
        raise ValueError("changepoints has length {}, expected {} (values "
                         "has length {})".format(len(changepoints),
                                                 len(values) - 1,
                                                 len(values)))
    # name_scope uses these inputs to pick the graph and only recognizes
    # tensors at the top level of the list, so flatten the nested lists.
    scope_inputs = [param] + list(values) + list(changepoints)
    with tf.name_scope(name, "PiecewiseFunction", scope_inputs) as s_name:
        values = [tf.convert_to_tensor(y, dtype=dtype) for y in values]
        # Bind each value as a default argument; a plain `lambda: y` would
        # late-bind and every branch would return the last value.
        branch_fns = [lambda y=y: y for y in values]
        # tf.case (exclusive=False) stops at the first true predicate, so with
        # sorted changepoints this picks the first interval containing
        # `param`; the last value is the default (param >= all changepoints).
        predicates = [tf.less(param, x) for x in changepoints]
        return tf.case(list(zip(predicates, branch_fns[:-1])), branch_fns[-1],
                       name=s_name)
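# Web-page residue from the source this snippet was copied from:
# "评论列表" (comment list) and "文章目录" (table of contents).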