def __init__(self, inputs, sequence_length, embedding, sampling_probability,
             time_major=False, seed=None, scheduling_seed=None, name=None):
    """Initializes the helper and validates `sampling_probability`.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      embedding: Either a callable that takes a vector tensor of `ids`
        (argmax ids), or the `params` argument for `embedding_lookup`.
      sampling_probability: A scalar or vector `float32` tensor: the
        probability of sampling categorically from the output ids instead
        of reading directly from the inputs.
      time_major: Python bool. Whether the tensors in `inputs` are time
        major. If `False` (default), they are assumed to be batch major.
      seed: The sampling seed.
      scheduling_seed: The schedule decision rule sampling seed.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sampling_probability` is not a scalar or vector
        (a statically unknown rank is rejected as well).
    """
    with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper",
                        [embedding, sampling_probability]):
        # Normalize `embedding` into a function of ids: use it as-is when it
        # is callable, otherwise treat it as the `params` table for
        # `embedding_lookup`.
        self._embedding_fn = (
            embedding if callable(embedding)
            else lambda ids: embedding_ops.embedding_lookup(embedding, ids))

        self._sampling_probability = ops.convert_to_tensor(
            sampling_probability, name="sampling_probability")
        prob_shape = self._sampling_probability.get_shape()
        # Only rank 0 or rank 1 is allowed; an unknown rank (`ndims` of
        # None) is not in the allowed set, so it is rejected too.
        if prob_shape.ndims not in (0, 1):
            raise ValueError(
                "sampling_probability must be either a scalar or a vector. "
                "saw shape: %s" % (prob_shape,))

        # Kept for use elsewhere in the class (not visible here) —
        # presumably by the sampling step; verify against the rest of the file.
        self._seed = seed
        self._scheduling_seed = scheduling_seed

        super(ScheduledEmbeddingTrainingHelper, self).__init__(
            inputs=inputs,
            sequence_length=sequence_length,
            time_major=time_major,
            name=name)
评论列表
文章目录