lstm_ops.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def __call__(self, x, states_prev, scope=None):
    """Long short-term memory cell (LSTM).

    Runs one step of the cell through ``_lstm_block_cell``, creating (or
    reusing) the cell's variables inside a variable scope.

    Args:
      x: Input tensor. Its static shape must be rank 2 and its second
        dimension (used as the input size) must be statically known.
      states_prev: Length-2 sequence ``(cs_prev, h_prev)``; it is unpacked
        in that order and passed to ``_lstm_block_cell``.
      scope: Optional variable-scope name; defaults to the class name.

    Returns:
      The pair ``(h, (cs, h))``: the new output ``h`` and the new state
      tuple ``(cs, h)`` in the same order as ``states_prev``.

    Raises:
      ValueError: if ``x`` is not rank 2 (raised by ``with_rank``), if the
        second dimension of ``x`` is not statically known, or if
        ``states_prev`` does not have length 2.
    """
    with vs.variable_scope(scope or type(self).__name__):
        x_shape = x.get_shape().with_rank(2)
        # Test the dimension's underlying value, not the dimension object:
        # truthiness of a ``Dimension`` object is version-dependent, while
        # its ``.value`` is ``None`` when the size is unknown. ``getattr``
        # also covers shapes whose entries are already plain ints / None.
        input_dim = x_shape[1]
        input_size = getattr(input_dim, "value", input_dim)
        if not input_size:
            raise ValueError("Expecting x_shape[1] to be set: %s" % str(x_shape))
        if len(states_prev) != 2:
            raise ValueError("Expecting states_prev to be a tuple with length 2.")

        # Four gate blocks side by side; the bias has one entry per column
        # of W (this equals W's second dimension).
        gate_units = self._num_units * 4
        # NOTE(review): the row count (input_size + num_units) suggests the
        # kernel is applied to [x, h_prev] concatenated — confirm against
        # the _lstm_block_cell op definition.
        w = vs.get_variable("W", [input_size + self._num_units, gate_units])
        b = vs.get_variable("b", [gate_units],
                            initializer=init_ops.constant_initializer(0.0))
        # NOTE(review): peephole weights are created even when
        # self._use_peephole is False; presumably kept so the variable set
        # (and checkpoints) do not depend on that flag — confirm before
        # making their creation conditional.
        wci = vs.get_variable("wci", [self._num_units])
        wco = vs.get_variable("wco", [self._num_units])
        wcf = vs.get_variable("wcf", [self._num_units])

        cs_prev, h_prev = states_prev
        # _lstm_block_cell yields seven outputs; only the new cell state
        # (index 1) and the new output (index 6) are needed here.
        _, cs, _, _, _, _, h = _lstm_block_cell(
            x,
            cs_prev,
            h_prev,
            w,
            b,
            wci=wci,
            wco=wco,
            wcf=wcf,
            forget_bias=self._forget_bias,
            use_peephole=self._use_peephole)

        return h, (cs, h)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号