rhn.py source code

python

Project: RecurrentHighwayNetworks  Author: julian121266
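def __call__(self, inputs, state, scope=None):
    # state packs the recurrent state together with two multiplicative noise tensors.
    current_state = state[0]
    noise_i = state[1]  # noise multiplied into the inputs
    noise_h = state[2]  # noise multiplied into the recurrent state
    for i in range(self.depth):
      # Candidate activation h; only the first micro-layer (i == 0) sees the inputs.
      with tf.variable_scope('h_'+str(i)):
        if i == 0:
          h = tf.tanh(linear([inputs * noise_i, current_state * noise_h], self._num_units, True))
        else:
          h = tf.tanh(linear([current_state * noise_h], self._num_units, True))
      # Transform gate t; forget_bias is the fourth argument to linear (presumably the initial bias).
      with tf.variable_scope('t_'+str(i)):
        if i == 0:
          t = tf.sigmoid(linear([inputs * noise_i, current_state * noise_h], self._num_units, True, self.forget_bias))
        else:
          t = tf.sigmoid(linear([current_state * noise_h], self._num_units, True, self.forget_bias))
      # Highway update, algebraically equal to h * t + current_state * (1 - t).
      current_state = (h - current_state) * t + current_state

    # The noise tensors are passed through unchanged as part of the new state.
    return current_state, [current_state, noise_i, noise_h]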
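
The method above depends on TensorFlow and on a `linear` helper that is not shown here. To make the update rule easier to follow in isolation, here is a minimal pure-Python sketch of the same step. It is an illustration under assumptions, not the project's implementation: `dense`, `make_params`, and `rhn_step` are hypothetical names, `dense` stands in for `linear` (concatenate the inputs, then apply an affine map), and the weights are small random values chosen only for the demo.

```python
import math
import random


def dense(vectors, weights, bias):
    """Hypothetical stand-in for `linear`: concatenate the vectors, then apply W x + b."""
    x = [v for vec in vectors for v in vec]
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def _matrix(rng, rows, cols):
    """Small random weight matrix, used only for this demo."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def make_params(num_inputs, num_units, depth, gate_bias, seed=0):
    """Build (W_h, b_h, W_t, b_t) per micro-layer; gate_bias plays the role of forget_bias."""
    rng = random.Random(seed)
    params = []
    for i in range(depth):
        # Only the first micro-layer receives the external inputs.
        fan_in = num_inputs + num_units if i == 0 else num_units
        params.append((_matrix(rng, num_units, fan_in), [0.0] * num_units,
                       _matrix(rng, num_units, fan_in), [gate_bias] * num_units))
    return params


def rhn_step(inputs, state, params, noise_i, noise_h):
    """One time step: apply every micro-layer's highway update to the state in turn."""
    s = list(state)
    for i, (w_h, b_h, w_t, b_t) in enumerate(params):
        masked_s = [a * m for a, m in zip(s, noise_h)]
        if i == 0:
            feed = [[a * m for a, m in zip(inputs, noise_i)], masked_s]
        else:
            feed = [masked_s]
        h = [math.tanh(z) for z in dense(feed, w_h, b_h)]  # candidate activation
        t = [sigmoid(z) for z in dense(feed, w_t, b_t)]    # transform gate
        # Same update as the original: (h - s) * t + s == h * t + s * (1 - t).
        s = [(hj - sj) * tj + sj for hj, sj, tj in zip(h, s, t)]
    return s


if __name__ == "__main__":
    state = [0.1, -0.2, 0.3, 0.4]
    params = make_params(num_inputs=3, num_units=4, depth=2, gate_bias=-2.0)
    print(rhn_step([1.0, 2.0, 3.0], state, params, [1.0] * 3, [1.0] * 4))
```

Because each micro-layer forms a convex combination of `h` and the previous state, a strongly negative gate bias keeps the gate almost closed and carries the state through nearly unchanged. A strongly positive bias replaces the state with the candidate `h`.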