superhighway.py source code

python

Project: Chinese-QA  Author: distantJing
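```python
import tensorflow as tf              # pre-1.0 TensorFlow API (see tf.split below)
from my.tensorflow.nn import linear  # project helper linear(args, output_size, bias); import path assumed


def __call__(self, inputs, state, scope=None):
    # Method of the SHCell RNN cell; `inputs` is [h, u] concatenated along axis 1.
    with tf.variable_scope(scope or "SHCell"):
        # The gate is either one scalar per example or a vector of size state_size.
        a_size = 1 if self._scalar else self._state_size
        h, u = tf.split(1, 2, inputs)  # old signature: split(split_dim, num_split, value)
        if self._logit_func == 'mul_linear':
            args = [h * u, state * u]
            a = tf.nn.sigmoid(linear(args, a_size, True))
        elif self._logit_func == 'linear':
            args = [h, u, state]
            a = tf.nn.sigmoid(linear(args, a_size, True))
        elif self._logit_func == 'tri_linear':
            args = [h, u, state, h * u, state * u]
            a = tf.nn.sigmoid(linear(args, a_size, True))
        elif self._logit_func == 'double':
            # Two-layer gate: a tanh hidden layer followed by a sigmoid projection.
            args = [h, u, state]
            a = tf.nn.sigmoid(linear(tf.tanh(linear(args, a_size, True)), self._state_size, True))
        else:
            raise ValueError("Unknown logit_func: %r" % self._logit_func)
        # Highway-style update: keep a fraction `a` of the old state, take the rest from h.
        new_state = a * state + (1 - a) * h
        # As written in the source, the emitted output is the state *before* this update.
        outputs = state
        return outputs, new_state
```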
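To make the gating arithmetic concrete, here is a minimal pure-Python sketch of one step of the cell with the `'linear'` logit function, i.e. a = sigmoid(W·[h; u; state] + b) followed by new_state = a * state + (1 - a) * h. It is not the project's code: `sh_cell_step`, the explicit `weights`/`bias` arguments, and the list-based `linear` are hypothetical stand-ins for the TensorFlow version, chosen so the update can be checked by hand. Passing a single weight row mimics the scalar gate (`a_size == 1`), which is broadcast over the state.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def linear(args: Sequence[Vector], weights: List[Vector], bias: Vector) -> Vector:
    # Hypothetical stand-in for the project's `linear`: concatenate args, apply W x + b.
    x = [v for arg in args for v in arg]
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weights, bias)]


def sh_cell_step(inputs: Vector, state: Vector,
                 weights: List[Vector], bias: Vector) -> Tuple[Vector, Vector]:
    # One SHCell step with the 'linear' logit function (hypothetical helper).
    d = len(state)
    if len(inputs) != 2 * d:
        raise ValueError("inputs must be [h, u] with len(inputs) == 2 * len(state)")
    h, u = inputs[:d], inputs[d:]  # same split as tf.split on axis 1
    a = [sigmoid(z) for z in linear([h, u, state], weights, bias)]
    if len(a) == 1:  # scalar gate: broadcast across the state, like a_size == 1
        a = a * d
    new_state = [ai * si + (1.0 - ai) * hi for ai, si, hi in zip(a, state, h)]
    outputs = state  # mirrors the source: output is the pre-update state
    return outputs, new_state
```

With all-zero weights and bias the gate is sigmoid(0) = 0.5, so `sh_cell_step([1.0, 2.0, 3.0, 4.0], [10.0, 20.0], [[0.0] * 6] * 2, [0.0, 0.0])` yields a new state halfway between `state` and `h`, namely `[5.5, 11.0]`. A large positive bias pushes `a` toward 1, so the cell mostly carries its state forward unchanged.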