custom_layers.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:urnn 作者: stwisdom 项目源码 文件源码
def build(self, input_shape):
        """Create the layer's weights and initial hidden state for ``input_shape``.

        Builds the input projection ``self.W`` and the bias ``self.b`` along
        with its tiled copy ``self.baug``. Builds the initial hidden state
        ``self.h0`` and a recurrence operator whose form depends on
        ``self.unitary_impl``:

        * any value *containing* ``'full'`` (substring test): a full matrix
          ``self.U`` initialized according to ``self.inner_init``
          (``'svd'`` or ``'ASB2016'``), then expanded to ``self.Uaug`` with
          ``augRight``;
        * exactly ``'ASB2016'``: the parameterization returned by
          ``unitary_ASB2016_init`` (``self.theta``, ``self.reflection`` and
          the resulting ``self.Uaug``).

        Afterwards it sets ``self.trainable_weights`` for the chosen
        implementation and resets ``self.regularizers`` to an empty list.
        If ``self.initial_weights`` is not None, it loads them and deletes
        the attribute.

        Args:
            input_shape: shape of the layer input. Only ``input_shape[2]`` is
                read here; it is used as the input feature dimension (the
                first dimension of ``W``). Presumably
                ``(batch, timesteps, input_dim)``; TODO confirm against
                callers.

        NOTE(review): the initializer calls below presumably draw random
        values in sequence, so their order determines the initial weights.
        Keep it stable if reproducibility matters.
        """
        # Record the expected input shape on the layer.
        self.input_spec = [InputSpec(shape=input_shape)]
        if self.stateful:
            # Stateful layers (re)create their persistent state here.
            self.reset_states()
        else:
            # Non-stateful: a single placeholder entry. The actual initial
            # state is not built in this method; presumably it is created
            # elsewhere in the class (e.g. from self.h0). TODO confirm.
            self.states = [None]
        input_dim = input_shape[2]
        self.input_dim = input_dim

        # Input-to-hidden projection of shape (input_dim, output_dim).
        self.W = self.init((input_dim, self.output_dim),
                           name='{}_W'.format(self.name))
        # Previous zero-initialized bias, kept for reference:
        #self.b = K.zeros((self.N,), name='{}_b'.format(self.name))
        # Bias of length N with small uniform init (scale=0.01).
        self.b = initializations.uniform((self.N,),scale=0.01,name='{}_b'.format(self.name))
        # Bias repeated twice -> length 2*N. Presumably this applies the same
        # bias to both halves (e.g. real/imaginary) of a 2N-dim real
        # representation of the hidden state; confirm against the step code.
        self.baug=K.tile(self.b,[2])

        # Initial hidden state of length 2*N: self.h0_mean plus small uniform
        # noise. .get_value() extracts the initializer's value so it can be
        # wrapped in a fresh variable. h0 appears in trainable_weights below,
        # so it is learned.
        h0 = self.h0_mean+initializations.uniform((2*self.N,),scale=0.01).get_value()
        self.h0 = K.variable(h0,name='{}_h0'.format(self.name))

        # Substring test: every impl name containing 'full' takes this branch.
        if ('full' in self.unitary_impl):
            # Full recurrence matrix U, learned directly.

            if (self.inner_init=='svd'):
                # Initialize U with the project's SVD-based unitary initializer.
                self.U = unitary_svd_init((self.N, self.N),name='{}_U'.format(self.name))
            elif (self.inner_init=='ASB2016'):
                # Initialize U from the [ASB2016] parameterization: take its
                # first return value (Uaug), evaluate it to a concrete array,
                # then stack its top-left and top-right N-column blocks
                # vertically (axis=0).
                # Assumes Uaug is the 2N x 2N augmented real form of a complex
                # N x N matrix; TODO confirm against unitary_ASB2016_init.
                Uaug,_,_,_ = unitary_ASB2016_init((self.N,self.N))
                Uaug=Uaug.eval()
                self.U=K.variable(np.concatenate((Uaug[:self.N,:self.N],Uaug[:self.N,self.N:]),axis=0),name='{}_U'.format(self.name))
            # NOTE(review): for any other inner_init value, self.U is not
            # assigned here. The next line then fails unless U was set
            # elsewhere in the class.

            # Expand U into the operator used by the recurrence.
            self.Uaug=augRight(self.U,module=K)

        elif (self.unitary_impl=='ASB2016'):
            # Parameterization of [Arjovsky, Shah, Bengio 2016]. theta and
            # reflection are the learned parameters; Uaug is built from them.
            self.Uaug,self.theta,self.reflection,_ = unitary_ASB2016_init((self.N, self.N),name=self.name)
        # NOTE(review): any other unitary_impl value builds no recurrence
        # weights here, and trainable_weights is not set below.

        # Set the trainable weights for the chosen implementation.
        if ('full' in self.unitary_impl):
            self.trainable_weights = [self.W, self.U, self.b, self.h0]
        elif (self.unitary_impl=='ASB2016'):
            self.trainable_weights = [self.W, self.theta, self.reflection, self.b, self.h0]

        # Weight regularization is currently disabled; the hookup code is
        # kept commented out below.
        self.regularizers = []
        #if self.W_regularizer:
        #    self.W_regularizer.set_param(self.W)
        #    self.regularizers.append(self.W_regularizer)
        #if self.U_regularizer:
        #    self.U_regularizer.set_param(self.U)
        #    self.regularizers.append(self.U_regularizer)
        #if self.b_regularizer:
        #    self.b_regularizer.set_param(self.b)
        #    self.regularizers.append(self.b_regularizer)

        # Load user-supplied weights, then drop the reference to them.
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号