rnn.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def _apply(self, X, state=None, memory=None):
  """Run the wrapped RNN cell(s) over ``X``.

  Optionally wraps the cell(s) in an attention mechanism (built once and
  cached on the instance), dispatches to the dynamic or static, uni- or
  bidirectional TensorFlow RNN builder, initializes this layer's variables
  on the first call, and normalizes the bidirectional outputs.

  Parameters
  ----------
  X : Tensor, or list/tuple of Tensors
      Input sequence. The dynamic builders are called without
      ``time_major``, so TensorFlow's default layout for them applies
      (``time_major=False``: ``[batch_size, max_time, depth]``). The static
      builders expect a length-T list of per-timestep tensors.
  state : optional
      Initial state. For bidirectional RNNs a tuple/list is read as
      ``(state_fw[, state_bw])``; any other value is used as the forward
      initial state only.
  memory : optional
      Forwarded to the attention creator when ``self.attention`` is set.

  Returns
  -------
  The RNN outputs when ``self.return_states`` is False, otherwise an
  ``(outputs, states)`` pair. For bidirectional RNNs the forward and
  backward outputs are joined along the last axis, and ``states`` is
  ``(state_fw, state_bw)`` for both the dynamic and static builders.
  """
  # ====== create attention if necessary ====== #
  cell = self.cell
  cell_bw = self.cell_bw if self.bidirectional else None
  if self.attention:
    # Build each attention wrapper once, then use the cached wrapper on
    # every call (the forward branch must mirror the backward one).
    if not hasattr(self, "_cell_with_attention"):
      self._cell_with_attention = self.__attention_creator(
          cell, X=X, memory=memory)
    cell = self._cell_with_attention
    if self.bidirectional:
      if not hasattr(self, "_cell_with_attention_bw"):
        self._cell_with_attention_bw = self.__attention_creator(
            cell_bw, X=X, memory=memory)
      cell_bw = self._cell_with_attention_bw
  # Static builders take a list of per-timestep tensors, which has no
  # `.dtype`; read the dtype from the first step in that case.
  first_input = X[0] if isinstance(X, (tuple, list)) else X
  dtype = first_input.dtype.base_dtype
  # ====== calling rnn wrapper ====== #
  if self.bidirectional:
    rnn_func = rnn.bidirectional_dynamic_rnn if self.dynamic \
        else rnn.static_bidirectional_rnn
    # NOTE(review): any tuple/list `state` is split into (fw, bw), so a
    # single cell's tuple state (e.g. an LSTM (c, h) pair) passed for the
    # forward direction alone would be misread — confirm callers.
    if isinstance(state, (tuple, list)):
      state_fw = state[0]
      state_bw = state[1] if len(state) > 1 else None
    else:
      state_fw, state_bw = state, None
    outputs = rnn_func(cell_fw=cell, cell_bw=cell_bw, inputs=X,
                       initial_state_fw=state_fw,
                       initial_state_bw=state_bw,
                       dtype=dtype)
  else:
    rnn_func = rnn.dynamic_rnn if self.dynamic else rnn.static_rnn
    outputs = rnn_func(cell, inputs=X, initial_state=state, dtype=dtype)
  # ====== initialize cell ====== #
  if not self._is_initialized_variables:
    # Initialize only once: running the initializer again would reset the
    # variables to fresh initial values on every call.
    K.eval(tf.variables_initializer(self.variables))
    self._is_initialized_variables = True
    _infer_variable_role(self.variables)
  # ====== return ====== #
  if self.bidirectional:
    if self.dynamic:
      # bidirectional_dynamic_rnn -> ((out_fw, out_bw), (st_fw, st_bw));
      # join both directions along the feature axis.
      outputs = (tf.concat(outputs[0], axis=-1), outputs[1])
    else:
      # static_bidirectional_rnn -> (outputs, st_fw, st_bw), with each
      # per-step output already depth-concatenated; only regroup the
      # states so both paths return (outputs, (state_fw, state_bw)).
      outputs = (outputs[0], (outputs[1], outputs[2]))
  if not self.return_states:
    return outputs[0]
  return outputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号