def __init__(self,
             cell,
             target_column,
             optimizer,
             model_dir=None,
             config=None,
             gradient_clipping_norm=None,
             inputs_key='inputs',
             sequence_length_key='sequence_length',
             initial_state_key='initial_state',
             dtype=None,
             parallel_iterations=None,
             swap_memory=False,
             name=None):
  """Initialize `DynamicRNNEstimator`.

  Args:
    cell: an initialized `RNNCell` to be used in the RNN.
    target_column: an initialized `TargetColumn`, used to calculate loss and
      metrics.
    optimizer: an initialized `tensorflow.Optimizer`.
    model_dir: The directory in which to save and restore the model graph,
      parameters, etc.
    config: A `RunConfig` instance.
    gradient_clipping_norm: parameter used for gradient clipping. If `None`,
      then no clipping is performed.
    inputs_key: the key for input values in the features dict passed to
      `fit()`.
    sequence_length_key: the key for the sequence length tensor in the
      features dict passed to `fit()`.
    initial_state_key: the key for the initial RNN state in the features dict
      passed to `fit()`.
    dtype: Parameter passed to `dynamic_rnn`. The dtype of the state and
      output returned by `RNNCell`.
    parallel_iterations: Parameter passed to `dynamic_rnn`. The number of
      iterations to run in parallel.
    swap_memory: Parameter passed to `dynamic_rnn`. Transparently swap the
      tensors produced in forward inference but needed for back prop from
      GPU to CPU.
    name: Optional name for the `Estimator`.
  """
  super(_DynamicRNNEstimator, self).__init__(
      model_dir=model_dir, config=config)
  self._cell = cell
  self._target_column = target_column
  self._optimizer = optimizer
  self._gradient_clipping_norm = gradient_clipping_norm
  self._inputs_key = inputs_key
  self._sequence_length_key = sequence_length_key
  self._initial_state_key = initial_state_key
  # Default the computation dtype to float32 when the caller passes None.
  self._dtype = dtype or dtypes.float32
  self._parallel_iterations = parallel_iterations
  self._swap_memory = swap_memory
  # Default the estimator name when the caller passes None/empty.
  self._name = name or 'DynamicRnnEstimator'