rnn_cell.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:diversity_based_attention 作者: PrekshaNema25 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
    """Run one step of an LSTM cell extended with a "distraction" gate.

    This is the standard LSTM cell update plus a fifth gate ``g``. Before the
    output activation, a gated fraction of the new cell state's component
    along the previous cell state ``c`` is subtracted. The emitted hidden
    state is therefore pushed away from the direction ``c`` already encodes.

    Args:
      inputs: input tensor for this step. NOTE(review): assumed rank-2
        ``(batch, input_dim)``; the axis-1 splits and reductions below treat
        axis 1 as the feature axis.
      state: previous state. It is either an ``LSTMStateTuple(c, h)`` when
        ``self._state_is_tuple`` is set, or ``c`` and ``h`` concatenated
        along axis 1.
      scope: optional variable scope name; defaults to the class name.

    Returns:
      A pair ``(new_h, new_state)``. ``new_state`` holds the *undistracted*
      cell state ``new_c`` together with ``new_h``, in the same layout
      (tuple or concatenated) as ``state``.
    """
    with vs.variable_scope(scope or type(self).__name__):
        if self._state_is_tuple:
            c, h = state
        else:
            c, h = array_ops.split(1, 2, state)

        # Parameters of all five gates are concatenated into one multiply
        # for efficiency.
        concat = _linear([inputs, h], 5 * self._num_units, True)

        # i = input gate, j = new input, f = forget gate, o = output gate,
        # g = distraction gate.
        i, j, f, o, g = array_ops.split(1, 5, concat)

        new_c = (c * sigmoid(f + self._forget_bias) +
                 sigmoid(i) * self._activation(j))

        # Per-row scalar projection coefficient of new_c onto c:
        #   <c, new_c> / (<c, c> + eps)
        # eps keeps the division finite when a row of c is all zeros
        # (e.g. a zero initial state).
        eps = 1e-13
        proj_coef = (math_ops.reduce_sum(c * new_c, 1) /
                     (math_ops.reduce_sum(c * c, 1) + eps))

        # Broadcast the per-row coefficient across the feature axis and gate
        # it element-wise. This yields the same values as the former
        # transpose -> multiply -> transpose sequence, without the transposes.
        distract_scale = sigmoid(g) * array_ops.expand_dims(proj_coef, 1)

        # Remove the gated projection of new_c onto c.
        distract_c = new_c - c * distract_scale

        new_h = self._activation(distract_c) * sigmoid(o)

        # NOTE(review): the carried cell state is new_c, not distract_c, so
        # the distraction only affects the emitted hidden state. Confirm
        # this is the intended design.
        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat(1, [new_c, new_h])

        return new_h, new_state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号