fused_rnn_cell.py source code

python

Project: lsdc  Author: febert
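# Method excerpt from a class. Assumption: `array_ops` and `rnn` are imported
# at module level from tensorflow.python.ops (pre-1.0 TensorFlow API).
def __call__(self,
             inputs,
             initial_state=None,
             dtype=None,
             sequence_length=None,
             scope=None):
    """Runs the wrapped cell over time-major `inputs`.

    `inputs` is either a list of per-time-step tensors or a single tensor
    with time on axis 0; `outputs` is returned in the same form.
    """
    is_list = isinstance(inputs, list)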
    if self._use_dynamic_rnn:
      if is_list:
        inputs = array_ops.pack(inputs)
      outputs, state = rnn.dynamic_rnn(
          self._cell,
          inputs,
          sequence_length=sequence_length,
          initial_state=initial_state,
          dtype=dtype,
          time_major=True,
          scope=scope)
      if is_list:
        # Convert outputs back to list
        outputs = array_ops.unpack(outputs)
    else:  # non-dynamic rnn
      if not is_list:
        inputs = array_ops.unpack(inputs)
      outputs, state = rnn.rnn(self._cell,
                               inputs,
                               initial_state=initial_state,
                               dtype=dtype,
                               sequence_length=sequence_length,
                               scope=scope)
      if not is_list:
        # Convert outputs back to tensor
        outputs = array_ops.pack(outputs)

    return outputs, state
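
The container handling in the method above can be easier to follow without TensorFlow. The following is a minimal pure-Python sketch of the same control flow, not the library's implementation. `Packed`, `pack`, `unpack`, `toy_dynamic_rnn`, `toy_static_rnn`, `adaptor`, and `running_sum_cell` are all hypothetical stand-ins introduced for illustration: `Packed` plays the role of a time-major tensor, and a plain `list` plays the role of a list of per-step tensors.

```python
class Packed(tuple):
    """Hypothetical stand-in for a time-major tensor (one entry per step)."""


def pack(steps):
    # Plays the role of array_ops.pack: list of steps -> one "tensor".
    return Packed(steps)


def unpack(tensor):
    # Plays the role of array_ops.unpack: one "tensor" -> list of steps.
    return list(tensor)


def _run(cell, steps, state):
    # Shared loop: apply the cell step by step, threading the state through.
    outputs = []
    for x in steps:
        y, state = cell(x, state)
        outputs.append(y)
    return outputs, state


def toy_dynamic_rnn(cell, tensor, state):
    # Stand-in for the dynamic path: consumes and produces a "tensor".
    outputs, state = _run(cell, tensor, state)
    return Packed(outputs), state


def toy_static_rnn(cell, steps, state):
    # Stand-in for the static path: consumes and produces a list.
    return _run(cell, steps, state)


def adaptor(cell, inputs, initial_state, use_dynamic_rnn):
    # Same branching as the method above: convert the input to whatever the
    # chosen path expects, then convert the output back to the caller's form.
    is_list = isinstance(inputs, list)
    if use_dynamic_rnn:
        if is_list:
            inputs = pack(inputs)
        outputs, state = toy_dynamic_rnn(cell, inputs, initial_state)
        if is_list:
            outputs = unpack(outputs)
    else:
        if not is_list:
            inputs = unpack(inputs)
        outputs, state = toy_static_rnn(cell, inputs, initial_state)
        if not is_list:
            outputs = pack(outputs)
    return outputs, state


def running_sum_cell(x, state):
    # Toy cell: the new state is the running sum, and it is also the output.
    new_state = state + x
    return new_state, new_state


list_out, list_state = adaptor(running_sum_cell, [1, 2, 3], 0, True)
packed_out, packed_state = adaptor(running_sum_cell, Packed([1, 2, 3]), 0, False)
```

In this sketch the caller gets back the same container type it passed in, regardless of which path runs internally, which appears to be the point of the `is_list` checks in the original method.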