lstm_ops.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
               sequence_length):
    """Run this LSTM on inputs, starting from the given state.

    Fetches the kernel `W_0`, the bias `B` and (when peepholes are enabled)
    the diagonal peephole weights through `vs.get_variable`, then hands the
    whole time-major sequence to the fused `block_lstm` op in a single call.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
        The static rank must be 3 and `input_size` must be statically known
        because it determines the shape of the kernel variable.
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output. It is
        also used as the dtype of every variable fetched here.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
        time_len)` or None. When None, the full `time_len` is used as the
        maximum sequence length.

    Returns:
      A pair containing:

      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
                         output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
                    output_size]`

    Raises:
      ValueError: if `inputs` does not have static rank 3, or if its last
        dimension (`input_size`) is not statically known.
    """

    inputs_shape = inputs.get_shape().with_rank(3)
    time_len = inputs_shape[0].value
    if time_len is None:
      # Static time dimension unknown: fall back to the dynamic shape, which
      # yields a tensor instead of a Python int.
      time_len = array_ops.shape(inputs)[0]
    input_size = inputs_shape[2].value
    if input_size is None:
      # Without this guard the kernel shape arithmetic below would fail with
      # an opaque `TypeError` (None + int).
      raise ValueError(
          "Could not infer input size from inputs.get_shape()[2]; the last "
          "dimension of `inputs` must be statically known.")

    # Single fused kernel applied to the concatenated [input, h_prev]; the
    # factor 4 provides one slice of width `_num_units` per LSTM gate.
    w = vs.get_variable(
        "W_0", [input_size + self._num_units, self._num_units * 4], dtype=dtype)
    b = vs.get_variable(
        "B", [w.get_shape().with_rank(2)[1]],
        initializer=init_ops.constant_initializer(0.0),
        dtype=dtype)
    if self._use_peephole:
      wci = vs.get_variable("W_I_diag", [self._num_units], dtype=dtype)
      wco = vs.get_variable("W_O_diag", [self._num_units], dtype=dtype)
      wcf = vs.get_variable("W_F_diag", [self._num_units], dtype=dtype)
    else:
      # The op still takes the peephole arguments when they are disabled, so
      # feed it zero placeholders of the right shape.
      wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype)

    # `seq_len_max` is cast to int64 in both branches. A dynamic `time_len`
    # taken from `array_ops.shape` is int32, so it must be cast just like the
    # reduced `sequence_length`.
    if sequence_length is None:
      max_seq_len = math_ops.to_int64(time_len)
    else:
      max_seq_len = math_ops.to_int64(math_ops.reduce_max(sequence_length))

    # Of the seven outputs returned by the op, only the cell state (2nd) and
    # the output (7th) are returned to the caller.
    _, cs, _, _, _, _, h = _lstm_ops_so.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=w,
        wci=wci,
        wco=wco,
        wcf=wcf,
        b=b,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号