recurrent.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:keras_bn_library 作者: bnsnapper 项目源码 文件源码
def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        self.input_dim = input_shape[2]

        self.W = self.init((self.output_dim, 4 * self.input_dim),
                           name='{}_W'.format(self.name))
        self.U = self.inner_init((self.input_dim, 4 * self.input_dim),
                                 name='{}_U'.format(self.name))
        self.b = K.variable(np.hstack((np.zeros(self.input_dim),
                                       K.get_value(self.forget_bias_init((self.input_dim,))),
                                       np.zeros(self.input_dim),
                                       np.zeros(self.input_dim))),
                            name='{}_b'.format(self.name))

        self.A = self.init((self.input_dim, self.output_dim),
                            name='{}_A'.format(self.name))
        self.ba = K.zeros((self.output_dim,), name='{}_ba'.format(self.name))


        self.trainable_weights = [self.W, self.U, self.b, self.A, self.ba]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号