tree_encoder.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:almond-nnparser 作者: Stanford-Mobisocial-IoT-Lab 项目源码 文件源码
def __call__(self, left_state, right_state, extra_input=None):
        with tf.variable_scope('TreeLSTM'):
            c1, h1 = left_state
            c2, h2 = right_state

            if extra_input is not None:
                input_concat = tf.concat((extra_input, h1, h2), axis=1)
            else:
                input_concat = tf.concat((h1, h2), axis=1)
            concat = tf.layers.dense(input_concat, 5 * self._num_cells)
            i, f1, f2, o, g = tf.split(concat, 5, axis=1)
            i = tf.sigmoid(i)
            f1 = tf.sigmoid(f1)
            f2 = tf.sigmoid(f2)
            o = tf.sigmoid(o)
            g = tf.tanh(g)

            cnew = f1 * c1 + f2 * c2 + i * g
            hnew = o * cnew

            newstate = LSTMStateTuple(c=cnew, h=hnew)
            return hnew, newstate
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号