reinforce.py source file

python

Project: TensorFlow-in-a-Nutshell  Author: camrongodbout
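import tensorflow as tf


def split_apply_merge(inp, partitions, fns):
  """Split input according to partitions, pass each part through fns, and merge.

  Args:
    inp: the input vector.
    partitions: integer tensor of the same length as inp; partitions[i] is
      the index into fns for inp[i] (0 or 1 when there are two functions).
    fns: the functions, one per partition.
  Returns:
    the vector routed, where routed[i] = fns[partitions[i]](inp[i])
  """
  # Split the input into one tensor per function.
  new_inputs = tf.dynamic_partition(inp, partitions, len(fns))
  # Apply each function to its own part of the input.
  new_outputs = [fns[i](x) for i, x in enumerate(new_inputs)]
  # Partition the original positions the same way so the outputs can be
  # put back in order. This relies on inp having a statically known length.
  new_indices = tf.dynamic_partition(
      tf.range(0, inp.get_shape()[0]), partitions, len(fns))
  # Interleave the outputs back into the original order.
  return tf.dynamic_stitch(new_indices, new_outputs)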
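
The routing performed by the function above can be easier to follow without TensorFlow in the way. The following is a minimal plain-Python sketch of the same split, apply, merge idea using lists. It is not part of the original project: `split_apply_merge_reference`, `double`, and `negate` are hypothetical names introduced only for illustration, and the sketch assumes each function returns one output per input element.

```python
def split_apply_merge_reference(inp, partitions, fns):
    """Plain-Python analogue of the routing above (hypothetical helper)."""
    # Split: collect values and their original positions per partition.
    groups = [[] for _ in fns]
    indices = [[] for _ in fns]
    for position, (value, part) in enumerate(zip(inp, partitions)):
        groups[part].append(value)
        indices[part].append(position)

    # Apply: each function receives its whole group at once.
    outputs = [fn(group) for fn, group in zip(fns, groups)]

    # Merge: write each result back to the position it came from.
    routed = [None] * len(inp)
    for idx_list, out_list in zip(indices, outputs):
        for position, value in zip(idx_list, out_list):
            routed[position] = value
    return routed


def double(xs):
    # Hypothetical example function for partition 0.
    return [2 * x for x in xs]


def negate(xs):
    # Hypothetical example function for partition 1.
    return [-x for x in xs]


result = split_apply_merge_reference([1, 2, 3, 4], [0, 1, 1, 0], [double, negate])
print(result)  # [2, -2, -3, 8]
```

Elements 1 and 4 are routed to `double`, elements 2 and 3 to `negate`, and the merge step restores the original ordering, which is the role the partitioned index tensor plays in the TensorFlow version.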