dynamic_rnn_estimator.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def __init__(self,
               cell,
               target_column,
               optimizer,
               model_dir=None,
               config=None,
               gradient_clipping_norm=None,
               inputs_key='inputs',
               sequence_length_key='sequence_length',
               initial_state_key='initial_state',
               dtype=None,
               parallel_iterations=None,
               swap_memory=False,
               name=None):
    """Initialize `DynamicRNNEstimator`.

    Only `model_dir` and `config` are forwarded to the base-class constructor;
    every other argument is stored on the instance (as a `_`-prefixed
    attribute) for later use.

    Args:
      cell: an initialized `RNNCell` to be used in the RNN.
      target_column: an initialized `TargetColumn`, used to calculate loss and
        metrics.
      optimizer: an initialized `tensorflow.Optimizer`.
      model_dir: The directory in which to save and restore the model graph,
        parameters, etc.
      config: A `RunConfig` instance.
      gradient_clipping_norm: parameter used for gradient clipping. If `None`,
        then no clipping is performed.
      inputs_key: the key for input values in the features dict passed to
        `fit()`.
      sequence_length_key: the key for the sequence length tensor in the
        features dict passed to `fit()`.
      initial_state_key: the key for the initial RNN state in the features
        dict passed to `fit()`.
      dtype: Parameter passed to `dynamic_rnn`. The dtype of the state and
        output returned by `RNNCell`. If `None` (or any other falsy value),
        `dtypes.float32` is used.
      parallel_iterations: Parameter passed to `dynamic_rnn`. The number of
        iterations to run in parallel.
      swap_memory: Parameter passed to `dynamic_rnn`. Transparently swap the
        tensors produced in forward inference but needed for back prop from GPU
        to CPU.
      name: Optional name for the `Estimator`. Defaults to
        `'DynamicRnnEstimator'` when `None` or empty.
    """
    # Explicit two-argument super() keeps this compatible with Python 2.
    super(_DynamicRNNEstimator, self).__init__(
        model_dir=model_dir, config=config)
    self._cell = cell
    self._target_column = target_column
    self._optimizer = optimizer
    self._gradient_clipping_norm = gradient_clipping_norm
    # Keys used to look up tensors in the features dict.
    self._inputs_key = inputs_key
    self._sequence_length_key = sequence_length_key
    self._initial_state_key = initial_state_key
    # `or` substitutes the default for None (and any other falsy value).
    self._dtype = dtype or dtypes.float32
    self._parallel_iterations = parallel_iterations
    self._swap_memory = swap_memory
    self._name = name or 'DynamicRnnEstimator'
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号