def _process_scale(self, scale, event_ndims):
    """Reshape `scale` into batch-ready form (helper for `__init__`).

    Singleton dimensions are placed around `S` (the shape of `scale`)
    according to the following table:

                          event_ndims
      scale.ndims      0               1
           0       [1]+S+[1,1]    "silent error"
           1       [ ]+S+[1,1]    "silent error"
           2       [ ]+S+[1,1]    [1]+S+[ ]
           3       [ ]+S+[1,1]    [ ]+S+[ ]
          ...        (same)          (same)

    The intent is to turn `scale` into something that can always be used
    as, say, the left-hand argument of `batch_matmul`.

    Args:
      scale: `Tensor`.
      event_ndims: `Tensor` (0D, `int32`).

    Returns:
      scale: `Tensor` with dims expanded according to the table above.
      batch_ndims: `Tensor` (0D, `int32`). The ndims of the `batch` portion.
    """
    rank = array_ops.rank(scale)

    # Only two cells of the table receive a leading singleton dimension:
    # (scale.ndims == 0, event_ndims == 0) and
    # (scale.ndims == 2, event_ndims == 1).
    rank0_event0 = math_ops.reduce_all([
        math_ops.equal(rank, 0),
        math_ops.equal(event_ndims, 0),
    ])
    rank2_event1 = math_ops.reduce_all([
        math_ops.equal(rank, 2),
        math_ops.equal(event_ndims, 1),
    ])
    wants_leading_dim = math_ops.reduce_any([rank0_event0, rank2_event1])
    num_left = math_ops.select(wants_leading_dim, 1, 0)

    # event_ndims == 0 gets two trailing singleton dimensions; otherwise none.
    num_right = math_ops.select(math_ops.equal(event_ndims, 0), 2, 0)

    # Target shape is [1]*num_left + S + [1]*num_right.
    leading_ones = array_ops.ones([num_left], dtype=dtypes.int32)
    current_shape = array_ops.shape(scale)
    trailing_ones = array_ops.ones([num_right], dtype=dtypes.int32)
    target_shape = array_ops.concat(
        0, (leading_ones, current_shape, trailing_ones))

    expanded = array_ops.reshape(scale, target_shape)
    return expanded, rank - 2 + num_right
# Python reduce_all() example source code
def _process_scale(self, scale, event_ndims):
    """Reshape `scale` into batch-ready, lower-triangular form.

    Helper for `__init__`. Singleton dimensions are placed around `S` (the
    shape of `scale`) according to the following table:

                          event_ndims
      scale.ndims      0               1
           0       [1]+S+[1,1]    "silent error"
           1       [ ]+S+[1,1]    "silent error"
           2       [ ]+S+[1,1]    [1]+S+[ ]
           3       [ ]+S+[1,1]    [ ]+S+[ ]
          ...        (same)          (same)

    The intent is to turn `scale` into something that can always be used
    as, say, the left-hand argument of `batch_matmul`. After reshaping, the
    upper-triangular part is zeroed out, and when `self.validate_args` is
    set, squareness and non-zero-diagonal assertions are attached to the
    returned tensor.

    Args:
      scale: `Tensor`.
      event_ndims: `Tensor` (0D, `int32`).

    Returns:
      scale: `Tensor` with dims expanded according to the table above.
      batch_ndims: `Tensor` (0D, `int32`). The ndims of the `batch` portion.
    """
    rank = array_ops.rank(scale)

    # Only two cells of the table receive a leading singleton dimension:
    # (scale.ndims == 0, event_ndims == 0) and
    # (scale.ndims == 2, event_ndims == 1).
    rank0_event0 = math_ops.reduce_all([
        math_ops.equal(rank, 0),
        math_ops.equal(event_ndims, 0),
    ])
    rank2_event1 = math_ops.reduce_all([
        math_ops.equal(rank, 2),
        math_ops.equal(event_ndims, 1),
    ])
    wants_leading_dim = math_ops.reduce_any([rank0_event0, rank2_event1])
    num_left = math_ops.select(wants_leading_dim, 1, 0)

    # event_ndims == 0 gets two trailing singleton dimensions; otherwise none.
    num_right = math_ops.select(math_ops.equal(event_ndims, 0), 2, 0)

    # Target shape is [1]*num_left + S + [1]*num_right.
    leading_ones = array_ops.ones([num_left], dtype=dtypes.int32)
    current_shape = array_ops.shape(scale)
    trailing_ones = array_ops.ones([num_right], dtype=dtypes.int32)
    target_shape = array_ops.concat(
        0, (leading_ones, current_shape, trailing_ones))

    expanded = array_ops.reshape(scale, target_shape)
    batch_ndims = rank - 2 + num_right

    # For safety, keep only the lower triangle (zero the upper part).
    lower = array_ops.matrix_band_part(expanded, -1, 0)

    if not self.validate_args:
        return lower, batch_ndims

    # Per the original author's note, matrix_band_part fails unless its
    # input is at least rank 2, so the last two dims are indexed directly.
    dims = array_ops.shape(lower)
    square_check = check_ops.assert_equal(
        dims[-2], dims[-1],
        message="Input must be a (batch of) square matrix.")

    # With a lower-triangular matrix, checking for zeros on the diagonal is
    # treated as sufficient to detect singularity.
    diagonal = array_ops.matrix_diag_part(lower)
    zero = ops.convert_to_tensor(0, dtype=diagonal.dtype)
    has_zero_on_diagonal = math_ops.reduce_any(
        math_ops.equal(diagonal, zero))
    nonsingular_check = control_flow_ops.Assert(
        math_ops.logical_not(has_zero_on_diagonal),
        ["Singular matrix encountered", diagonal])

    # Tie both assertions to the tensor that is handed back.
    checked = control_flow_ops.with_dependencies(
        [square_check, nonsingular_check], lower)
    return checked, batch_ndims
def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Compute the next inputs, mixing in sampled outputs where selected.

    The parent class's `next_inputs` supplies `finished`, the baseline next
    inputs and the state. Unless every entry of `finished` is true, the
    baseline inputs are replaced, at the positions selected by
    `sample_ids`, with values derived from `outputs` (optionally passed
    through `self._next_input_layer` and concatenated with auxiliary
    inputs read at `time + 1`).

    Args:
      time: Current time step; auxiliary inputs are read at `time + 1`.
      outputs: Outputs of the current step.
      state: State forwarded to the parent `next_inputs`.
      sample_ids: Used as a boolean mask selecting the entries to sample
        (it is passed to `where` and `logical_not`).
      name: Optional name for the enclosing name scope.

    Returns:
      A tuple `(finished, next_inputs, state)`.
    """
    with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                        [time, outputs, state, sample_ids]):
        parent = super(ScheduledOutputTrainingHelper, self)
        finished, base_next_inputs, state = parent.next_inputs(
            time=time,
            outputs=outputs,
            state=state,
            sample_ids=sample_ids,
            name=name)

        def append_auxiliary(outputs_, indices=None):
            """Concatenate outputs with auxiliary inputs, if they exist."""
            aux_arrays = self._auxiliary_input_tas
            if aux_arrays is None:
                return outputs_

            # Auxiliary inputs are read for the *next* step.
            following_step = time + 1
            aux_values = nest.map_structure(
                lambda ta: ta.read(following_step), aux_arrays)
            if indices is not None:
                # Keep only the rows that correspond to `indices`.
                aux_values = array_ops.gather_nd(aux_values, indices)

            def join_last_axis(x, y):
                return array_ops.concat((x, y), -1)

            return nest.map_structure(join_last_axis, outputs_, aux_values)

        def scheduled_sample():
            """Perform scheduled sampling."""
            if self._next_input_layer is None:
                # No projection layer: choose element-wise between the
                # (possibly augmented) outputs and the baseline inputs.
                return array_ops.where(
                    sample_ids, append_auxiliary(outputs), base_next_inputs)

            # Indices of the entries that are / are not being sampled.
            sampled_idx = math_ops.cast(
                array_ops.where(sample_ids), dtypes.int32)
            kept_idx = math_ops.cast(
                array_ops.where(math_ops.logical_not(sample_ids)),
                dtypes.int32)

            sampled_outputs = array_ops.gather_nd(outputs, sampled_idx)
            kept_inputs = array_ops.gather_nd(base_next_inputs, kept_idx)

            # Project only the sampled outputs, then add auxiliary inputs.
            projected = self._next_input_layer(sampled_outputs)
            sampled_inputs = append_auxiliary(projected, sampled_idx)

            # Scatter both groups back into the baseline shape; the two
            # index sets come from a mask and its negation, so summing
            # merges them.
            full_shape = array_ops.shape(base_next_inputs)
            scattered_sampled = array_ops.scatter_nd(
                indices=sampled_idx,
                updates=sampled_inputs,
                shape=full_shape)
            scattered_kept = array_ops.scatter_nd(
                indices=kept_idx,
                updates=kept_inputs,
                shape=full_shape)
            return scattered_sampled + scattered_kept

        # Once everything is finished, the sampling branch is skipped.
        everything_done = math_ops.reduce_all(finished)
        chosen_inputs = control_flow_ops.cond(
            everything_done, lambda: base_next_inputs, scheduled_sample)
        return (finished, chosen_inputs, state)