level2_model.py source code

python

Project: Skeleton-key    Author: feiyu1990
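import numpy as np
import tensorflow as tf  # uses the TF 1.x graph API (tf.variable_scope, tf.get_variable)


def _lstm(self, input_h, input_c, input_x, reuse=False):
    """Run one LSTM step with weights initialized from a pretrained model.

    self.model_load maps parameter names to arrays stored as
    (out_features, in_features); self.H is the hidden size.
    Returns the updated cell state c and hidden state h.
    """
    with tf.variable_scope('level2_lstm', reuse=reuse):
        # Transpose (out, in) -> (in, out) so that tf.matmul(x, W) is valid.
        w_i2h_ = np.transpose(self.model_load['/core/i2h_1/weight'][:], (1, 0))
        b_i2h_ = self.model_load['/core/i2h_1/bias'][:]
        w_h2h_ = np.transpose(self.model_load['/core/h2h_1/weight'][:], (1, 0))
        b_h2h_ = self.model_load['/core/h2h_1/bias'][:]

        # Create (or, with reuse=True, fetch) variables initialized from the arrays.
        w_i2h = tf.get_variable('w_i2h', initializer=w_i2h_)
        b_i2h = tf.get_variable('b_i2h', initializer=b_i2h_)
        w_h2h = tf.get_variable('w_h2h', initializer=w_h2h_)
        b_h2h = tf.get_variable('b_h2h', initializer=b_h2h_)

        # Pre-activations for all four gates at once: shape [batch, 4 * H].
        input_x = tf.cast(input_x, tf.float32)
        i2h = tf.matmul(input_x, w_i2h) + b_i2h
        h2h = tf.matmul(input_h, w_h2h) + b_h2h
        all_input_sums = i2h + h2h

        # Split into the input, forget, and output gates and the candidate values.
        reshaped = tf.reshape(all_input_sums, [-1, 4, self.H])
        n1, n2, n3, n4 = tf.unstack(reshaped, axis=1)
        in_gate = tf.sigmoid(n1)
        forget_gate = tf.sigmoid(n2)
        out_gate = tf.sigmoid(n3)
        in_transform = tf.tanh(n4)

        # c_t = f * c_{t-1} + i * g;  h_t = o * tanh(c_t)
        c = tf.multiply(forget_gate, input_c) + tf.multiply(in_gate, in_transform)
        h = tf.multiply(out_gate, tf.tanh(c))
        return c, h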
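To check the gate arithmetic of `_lstm` without building a TensorFlow 1.x graph, here is a minimal NumPy sketch of the same single step. It assumes, as the `np.transpose(..., (1, 0))` calls and the reshape to `[-1, 4, H]` suggest, that the pretrained weights are stored in a Torch-style `(out_features, in_features)` layout with the four gate blocks stacked in the order input, forget, output, candidate. The helper names `sigmoid` and `lstm_step_numpy` and the random weights are hypothetical and exist only for illustration.

```python
import numpy as np


def sigmoid(z):
    """Elementwise logistic function."""
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step_numpy(x, h, c, w_i2h_torch, b_i2h, w_h2h_torch, b_h2h):
    """One LSTM step mirroring _lstm in plain NumPy (hypothetical helper).

    Weights use the (out_features, in_features) layout: w_i2h_torch has shape
    (4*H, D) and w_h2h_torch has shape (4*H, H). They are transposed here,
    as _lstm does with np.transpose(..., (1, 0)).
    """
    H = h.shape[1]
    # Pre-activations for all four gates: shape (B, 4*H).
    all_input_sums = x @ w_i2h_torch.T + b_i2h + h @ w_h2h_torch.T + b_h2h
    # Same as tf.reshape(..., [-1, 4, H]) followed by tf.unstack(axis=1).
    reshaped = all_input_sums.reshape(-1, 4, H)
    n1, n2, n3, n4 = (reshaped[:, k, :] for k in range(4))
    in_gate = sigmoid(n1)
    forget_gate = sigmoid(n2)
    out_gate = sigmoid(n3)
    in_transform = np.tanh(n4)
    c_new = forget_gate * c + in_gate * in_transform
    h_new = out_gate * np.tanh(c_new)
    return c_new, h_new


# Usage with random weights (batch B=2, input size D=3, hidden size H=5).
rng = np.random.default_rng(0)
B, D, H = 2, 3, 5
x = rng.standard_normal((B, D)).astype(np.float32)
h0 = np.zeros((B, H), dtype=np.float32)
c0 = np.zeros((B, H), dtype=np.float32)
w_i2h = rng.standard_normal((4 * H, D)).astype(np.float32)
b_i2h = np.zeros(4 * H, dtype=np.float32)
w_h2h = rng.standard_normal((4 * H, H)).astype(np.float32)
b_h2h = np.zeros(4 * H, dtype=np.float32)

c1, h1 = lstm_step_numpy(x, h0, c0, w_i2h, b_i2h, w_h2h, b_h2h)
print(c1.shape, h1.shape)  # (2, 5) (2, 5)
```

Because all four gates are sliced from one `4*H`-wide pre-activation, gate k occupies columns `k*H` to `(k+1)*H`. The reshape to `[-1, 4, H]` relies on exactly this layout, so a checkpoint that stacks its gates in a different order would need its weight rows permuted before loading.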