rnn_cell.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell with attention (LSTMA).

    Runs one step of the wrapped cell (`self._cell`) on a linear projection
    of `[inputs, attns]`, then projects `[cell_output, new_attns]` to
    `self._attn_size` and appends that projection to the attention states.

    Args:
      inputs: Input for this step. Its second static dimension is used as
        the projection size when `self._input_size` is None.
      state: When `self._state_is_tuple` is true, a 3-tuple
        `(cell_state, attns, attn_states)`. Otherwise a single tensor that
        is sliced along axis 1 into those three parts, in that order.
      scope: Optional variable scope; defaults to the class name.

    Returns:
      A pair `(output, new_state)`. `new_state` has the same layout as
      `state`: a 3-tuple when `self._state_is_tuple` is true, otherwise the
      three parts concatenated along axis 1.
    """
    # NOTE(review): `array_ops.concat` is called as concat(axis, values)
    # throughout, which looks like the pre-1.0 TensorFlow argument order --
    # confirm against the TensorFlow version this file targets.
    with vs.variable_scope(scope or type(self).__name__):
        attn_size = self._attn_size
        attn_length = self._attn_length
        flat_attn_width = attn_size * attn_length

        if self._state_is_tuple:
            cell_state, attns, packed_attn_states = state
        else:
            # Flat layout along axis 1: [cell state | attns | attn states].
            cell_width = self._cell.state_size
            cell_state = array_ops.slice(state, [0, 0], [-1, cell_width])
            attns = array_ops.slice(state, [0, cell_width], [-1, attn_size])
            packed_attn_states = array_ops.slice(
                state, [0, cell_width + attn_size], [-1, flat_attn_width])

        # One row of width attn_size per remembered step.
        attn_window = array_ops.reshape(
            packed_attn_states, [-1, attn_length, attn_size])

        projection_size = self._input_size
        if projection_size is None:
            projection_size = inputs.get_shape().as_list()[1]

        # Mix the previous attention vector into the cell input.
        cell_inputs = _linear([inputs, attns], projection_size, True)
        lstm_output, next_cell_state = self._cell(cell_inputs, cell_state)

        # `_attention` receives the cell state as a single tensor, so a
        # nested (tuple) state is flattened and joined along axis 1 first.
        if self._state_is_tuple:
            attention_query = array_ops.concat(
                1, nest.flatten(next_cell_state))
        else:
            attention_query = next_cell_state
        new_attns, kept_attn_states = self._attention(
            attention_query, attn_window)

        with vs.variable_scope("AttnOutputProjection"):
            output = _linear([lstm_output, new_attns], attn_size, True)

        # Append this step's output as a new row, then flatten back.
        # NOTE(review): the reshape to attn_length * attn_size presumably
        # relies on `_attention` returning attn_length - 1 rows -- confirm.
        extended_window = array_ops.concat(
            1, [kept_attn_states, array_ops.expand_dims(output, 1)])
        next_attn_states = array_ops.reshape(
            extended_window, [-1, flat_attn_width])

        if self._state_is_tuple:
            return output, (next_cell_state, new_attns, next_attn_states)
        return output, array_ops.concat(
            1, [next_cell_state, new_attns, next_attn_states])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号