ntm.py source code (Python)

Project: Neural-Turing-Machine    Author: yeoedward

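import tensorflow as tf


# Method of the NTM cell class in ntm.py (class definition not shown).
def deserialize(self, state):
    """Split the flat recurrent state back into memory and head weights.

    Layout of `state` (shape [batch, state_size]):
      [ memory (mem_nrows * mem_ncols)
      | read weights  (n_heads * mem_nrows)
      | write weights (n_heads * mem_nrows) ]
    """
    # Deserialize the memory matrix from the previous timestep.
    mem_size = self.mem_nrows * self.mem_ncols
    M0 = tf.slice(state, [0, 0], [-1, mem_size])
    M0 = tf.reshape(M0, [-1, self.mem_nrows, self.mem_ncols])

    state_idx = mem_size

    # Deserialize read weights from the previous timestep.
    read_w0s = []
    for _ in range(self.n_heads):
      # Number of weights == number of rows in the memory matrix.
      w0 = tf.slice(state, [0, state_idx], [-1, self.mem_nrows])
      read_w0s.append(w0)
      state_idx += self.mem_nrows

    # Do the same for the write heads.
    write_w0s = []
    for _ in range(self.n_heads):
      w0 = tf.slice(state, [0, state_idx], [-1, self.mem_nrows])
      write_w0s.append(w0)
      state_idx += self.mem_nrows

    # Check that the layout consumed the whole state. In graph mode a
    # tf.Assert op only runs if something depends on it, so attach it
    # to the returned memory tensor as a control dependency.
    check = tf.Assert(
      tf.equal(state_idx, tf.shape(state)[1]),
      [tf.shape(state)],
    )
    with tf.control_dependencies([check]):
      M0 = tf.identity(M0)

    # Note: write weights are returned before read weights.
    return M0, write_w0s, read_w0s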
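To make the state layout concrete, here is a minimal NumPy sketch, assuming the layout implied by `deserialize()` above: the flattened memory first, then every read-head weighting, then every write-head weighting, each weighting `mem_nrows` wide. The helper names `serialize_state` and `deserialize_state` are hypothetical and are not taken from the yeoedward repository, whose serialization counterpart is not shown here. The sketch also swaps the repeated slices for a single split over the list of piece sizes.

```python
import numpy as np


def serialize_state(M, read_ws, write_ws):
    """Hypothetical inverse of deserialize(): flatten to [batch, state_size].

    Order: memory, then all read weights, then all write weights.
    """
    batch = M.shape[0]
    # Row-major flatten, matching the reshape used in deserialize().
    parts = [M.reshape(batch, -1)] + list(read_ws) + list(write_ws)
    return np.concatenate(parts, axis=1)


def deserialize_state(state, mem_nrows, mem_ncols, n_heads):
    """NumPy mirror of deserialize(), using one split instead of many slices."""
    sizes = [mem_nrows * mem_ncols] + [mem_nrows] * (2 * n_heads)
    # Same consistency check as the tf.Assert: the layout must cover the state.
    assert sum(sizes) == state.shape[1], state.shape
    # np.split takes cut points, i.e. the running offsets between pieces.
    pieces = np.split(state, np.cumsum(sizes)[:-1], axis=1)
    M = pieces[0].reshape(-1, mem_nrows, mem_ncols)
    read_ws = pieces[1:1 + n_heads]
    write_ws = pieces[1 + n_heads:]
    # Same return order as deserialize(): writes before reads.
    return M, write_ws, read_ws


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, nrows, ncols, heads = 2, 4, 3, 2
    M = rng.standard_normal((batch, nrows, ncols))
    reads = [rng.random((batch, nrows)) for _ in range(heads)]
    writes = [rng.random((batch, nrows)) for _ in range(heads)]

    state = serialize_state(M, reads, writes)
    print(state.shape)  # 4*3 memory entries + 2*2 heads * 4 rows = 28 columns
    M2, writes2, reads2 = deserialize_state(state, nrows, ncols, heads)
    print(np.array_equal(M, M2))
```

In TensorFlow the same single-split idea would be `tf.split(state, sizes, axis=1)` with the list of piece sizes; whether that is preferable to the per-head `tf.slice` calls in the original is a style choice, since both read the same columns.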