FastWeightsRNN.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:AssociativeRetrieval 作者: jxwufan 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
    """Run one step of the fast-weights RNN cell.

    The fast-weight matrix is first decayed and updated with the previous
    hidden state h(t):

        A(t) = lambda * A(t-1) + eta * h(t) h(t)^T

    and is then used for ``self._S`` inner iterations that refine the new
    hidden state:

        h_0(t+1)     = f(norm(W h(t) + C x(t)))
        h_{s+1}(t+1) = f(norm([W h(t) + C x(t)] + A(t) h_s(t+1)))

    (Ba et al., 2016, "Using Fast Weights to Attend to the Recent Past".)

    Args:
      inputs: the input for this step, x(t). Presumably a (batch, input_size)
        tensor -- TODO confirm against the caller.
      state: a 2-tuple ``(h, fast_weights)``. ``h`` is the previous hidden
        state h(t) in vector form, and ``fast_weights`` is the matrix returned
        by the previous step. NOTE(review): ``fast_weights`` is assumed to be
        (batch, num_units, num_units) so that it can be batch-multiplied with
        the matrix form of ``h``; confirm against the cell's zero_state.
      scope: optional variable-scope name; defaults to the class name.

    Returns:
      ``(h, (h, new_fast_weights))``: the new hidden state h(t+1) in vector
      form, both as the output and as the recurrent hidden state, together
      with the updated fast-weight matrix.
    """
    h_prev, fast_weights = state
    with vs.variable_scope(scope or type(self).__name__):

        def normalize(x, index):
            # With a shared norm, every application calls self._norm without
            # a scope. Otherwise each application gets its own scope
            # "Norm0" .. "Norm<S>".
            if self._reuse_norm:
                return self._norm(x)
            return self._norm(x, scope="Norm%d" % index)

        # Slow-weight term W h(t) + C x(t). It stays fixed across the inner
        # loop. The trailing False is presumably a "no bias" flag for
        # _fwlinear -- confirm in its definition.
        linear = self._fwlinear([h_prev, inputs], self._num_units, False)

        # Update the fast weights BEFORE the inner loop, so that the loop
        # attends to h(t) as well as to earlier hidden states. Previously the
        # loop used the un-updated matrix, which left the most recent hidden
        # state out of the fast-weight term. adj_y=True transposes the second
        # operand, so each batch entry gets the outer product h(t) h(t)^T.
        h_prev_mat = self._vector2matrix(h_prev)
        new_fast_weights = (
            self._lambda * fast_weights
            + self._eta * math_ops.batch_matmul(h_prev_mat, h_prev_mat, adj_y=True))

        # h_0(t+1) = f(norm(W h(t) + C x(t))). Move both terms to matrix form
        # so that they can be combined with batch_matmul results below.
        h = self._vector2matrix(self._activation(normalize(linear, 0)))
        linear = self._vector2matrix(linear)

        # h_{s+1}(t+1) = f(norm([W h(t) + C x(t)] + A(t) h_s(t+1))), S times.
        for i in range(self._S):
            h = self._activation(
                normalize(linear + math_ops.batch_matmul(new_fast_weights, h), i + 1))

        h = self._matrix2vector(h)
        return h, (h, new_fast_weights)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号