cells.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:emoatt 作者: epochx 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
  """Run the wrapped cell, then apply zoneout to the state it produced.

  Each state component is mixed between the previous state and the cell's
  proposed new state using the matching zoneout probability ``p``:

  * training (``self.is_training`` truthy): ``(1 - p) * dropout(new - old,
    1 - p) + old``, i.e. the update is randomized through ``dropout``;
  * otherwise: the deterministic blend ``p * old + (1 - p) * new``.

  Args:
    inputs: Passed unchanged to the wrapped cell ``self._cell``.
    state: Previous state. It is a tuple of components exactly when
      ``self.state_size`` is a tuple; in that case it is zipped with
      ``self._zoneout_prob`` component by component.
    scope: Optional scope, forwarded unchanged to the wrapped cell.

  Returns:
    ``(output, new_state)``. ``output`` is the wrapped cell's output,
    untouched. ``new_state`` is the zoned-out state; for tuple states it is
    rebuilt with ``self._tuple``.

  Raises:
    TypeError: if exactly one of ``self.state_size`` and
      ``self._zoneout_prob`` is a tuple.
    ValueError: if both are tuples but have different lengths.
  """
  state_is_tuple = isinstance(self.state_size, tuple)
  if state_is_tuple != isinstance(self._zoneout_prob, tuple):
    raise TypeError("Subdivided states need subdivided zoneouts.")
  if state_is_tuple and len(tuple(self.state_size)) != len(tuple(self._zoneout_prob)):
    raise ValueError("State and zoneout need equally many parts.")

  output, new_state = self._cell(inputs, state, scope)

  def _zoneout(new_part, old_part, prob):
    # Mix one state component between its old and new value using `prob`.
    if self.is_training:
      # NOTE(review): assumes `dropout(x, keep_prob)` is TF1-style inverted
      # dropout (zeroes entries with probability 1 - keep_prob and scales the
      # survivors by 1 / keep_prob). Multiplying by keep_prob undoes that
      # scaling, so each entry ends up as either `new_part` or `old_part`.
      return (1 - prob) * dropout(new_part - old_part, (1 - prob)) + old_part
    # Not training: use the deterministic (expected-value) blend.
    return prob * old_part + (1 - prob) * new_part

  if state_is_tuple:
    new_state = self._tuple([_zoneout(new_part, old_part, prob)
                             for new_part, old_part, prob in
                             zip(new_state, state, self._zoneout_prob)])
  else:
    # Single (non-tuple) state: the whole state shares one zoneout probability.
    new_state = _zoneout(new_state, state, self._zoneout_prob)
  return output, new_state

# # Wrap your cells like this
# cell = ZoneoutWrapper(tf.nn.rnn_cell.LSTMCell(hidden_units, initializer=random_uniform(), state_is_tuple=True),
# zoneout_prob=(z_prob_cells, z_prob_states))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号