def _process_scale(self, scale, event_ndims):
    """Reshape `scale` into batch-ready form on behalf of `__init__`.

    Size-1 dimensions are wrapped around `S = shape(scale)` as laid out in
    the following table:

                       event_ndims
      scale.ndims   0              1
                0   [1]+S+[1,1]    "silent error"
                1   [ ]+S+[1,1]    "silent error"
                2   [ ]+S+[1,1]    [1]+S+[ ]
                3   [ ]+S+[1,1]    [ ]+S+[ ]
              ...   (same)         (same)

    The goal is a `scale` that can always be used as, say, the left-hand
    argument of `batch_matmul`.

    Args:
      scale: `Tensor`.
      event_ndims: `Tensor` (0D, `int32`).

    Returns:
      scale: `Tensor` reshaped according to the table above.
      batch_ndims: `Tensor` (0D, `int32`). The ndims of the `batch` portion,
        computed as `rank(scale) - 2 + (number of trailing ones appended)`.
    """
    scale_rank = array_ops.rank(scale)

    # Exactly two cells of the table get a single leading 1:
    # (scale.ndims == 0, event_ndims == 0) and
    # (scale.ndims == 2, event_ndims == 1).
    rank0_event0 = math_ops.reduce_all([
        math_ops.equal(scale_rank, 0),
        math_ops.equal(event_ndims, 0),
    ])
    rank2_event1 = math_ops.reduce_all([
        math_ops.equal(scale_rank, 2),
        math_ops.equal(event_ndims, 1),
    ])
    needs_leading_one = math_ops.reduce_any([rank0_event0, rank2_event1])
    num_leading = math_ops.select(needs_leading_one, 1, 0)

    # When event_ndims == 0 two trailing 1s are appended; otherwise none.
    event_ndims_is_zero = math_ops.equal(event_ndims, 0)
    num_trailing = math_ops.select(event_ndims_is_zero, 2, 0)

    # Build the target shape: [1]*num_leading + shape(scale) + [1]*num_trailing.
    leading_ones = array_ops.ones([num_leading], dtype=dtypes.int32)
    original_shape = array_ops.shape(scale)
    trailing_ones = array_ops.ones([num_trailing], dtype=dtypes.int32)
    expanded_shape = array_ops.concat(
        0, (leading_ones, original_shape, trailing_ones))

    reshaped_scale = array_ops.reshape(scale, expanded_shape)
    # Uses the rank of the *input* scale; the leading 1 is not counted.
    batch_ndims = scale_rank - 2 + num_trailing
    return reshaped_scale, batch_ndims
# Comment list
# Table of contents