python类InputSpec()的实例源码

recurrent.py 文件源码 项目:keras_bn_library 作者: bnsnapper 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create this layer's recurrent and output-projection weights.

        Weights are created in this order: ``W`` (output_dim x 4*input_dim),
        ``U`` (input_dim x 4*input_dim), bias ``b`` (4*input_dim, second
        quarter taken from ``forget_bias_init``, the rest zeros), then ``A``
        (input_dim x output_dim) and its bias ``ba`` (output_dim,).
        ``input_dim`` is read from ``input_shape[2]``.

        NOTE(review): ``W`` is sized from ``output_dim`` while ``U`` is sized
        from ``input_dim``; presumably the previous output is fed back as the
        step input -- confirm against ``step``.
        """
        self.input_spec = [InputSpec(shape=input_shape)]
        self.input_dim = dim = input_shape[2]
        out_dim = self.output_dim
        gate_width = 4 * dim

        def _name(suffix):
            # Yields '<layer name>_<suffix>', matching the original names.
            return '{}_{}'.format(self.name, suffix)

        # Creation order is unchanged so initialisers run in the same sequence.
        self.W = self.init((out_dim, gate_width), name=_name('W'))
        self.U = self.inner_init((dim, gate_width), name=_name('U'))

        forget_bias = K.get_value(self.forget_bias_init((dim,)))
        bias_parts = (np.zeros(dim), forget_bias, np.zeros(dim), np.zeros(dim))
        self.b = K.variable(np.hstack(bias_parts), name=_name('b'))

        self.A = self.init((dim, out_dim), name=_name('A'))
        self.ba = K.zeros((out_dim,), name=_name('ba'))

        self.trainable_weights = [self.W, self.U, self.b, self.A, self.ba]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
layers.py 文件源码 项目:keras-fcn 作者: JihongJu 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def __init__(self, target_shape, offset=None, data_format=None,
                 **kwargs):
        """Crop to target.

        If `offset` is a single int, both spatial dimensions are offset by
        this amount.

        Arguments:
            target_shape: the shape to crop to; stored unchanged on
                ``self.target_shape``.
            offset: ``None`` or ``'centered'`` (stored as ``'centered'``),
                a single ``int`` (expanded to ``(offset, offset)``), or a
                sequence of exactly two offsets (stored as given).
            data_format: normalized with
                ``conv_utils.normalize_data_format``.
            **kwargs: forwarded to the base layer constructor.

        Raises:
            ValueError: if `offset` is a sequence whose length is not 2, or
                is of any other unsupported type.
        """
        super(CroppingLike2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.target_shape = target_shape
        if offset is None or offset == 'centered':
            self.offset = 'centered'
        elif isinstance(offset, int):
            self.offset = (offset, offset)
        elif hasattr(offset, '__len__'):
            if len(offset) != 2:
                raise ValueError('`offset` should have two elements. '
                                 'Found: ' + str(offset))
            self.offset = offset
        else:
            # Without this branch an unsupported value (e.g. a float) would
            # leave ``self.offset`` unset and only fail much later.
            raise ValueError('`offset` should be None, "centered", an int '
                             'or a sequence of two elements. '
                             'Found: ' + str(offset))
        self.input_spec = InputSpec(ndim=4)
layers.py 文件源码 项目:mctest-model 作者: Maluuba 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Build the wrapped layer on the input shape with leading axes collapsed.

        The child layer is built (only if not already built) on
        ``(input_shape[0],) + input_shape[self.first_n:]``, i.e. axes
        ``1 .. first_n - 1`` are removed.

        Raises:
            Exception: on the TensorFlow backend when ``input_shape[1]`` is
                not defined.
        """
        assert len(input_shape) >= 3
        self.input_spec = [InputSpec(shape=input_shape)]

        # NOTE(review): only axis 1 is checked here, even when first_n > 2.
        needs_static_timesteps = K._BACKEND == 'tensorflow'
        if needs_static_timesteps and not input_shape[1]:
            raise Exception('When using TensorFlow, you should define '
                            'explicitly the number of timesteps of '
                            'your sequences.\n'
                            'If your first layer is an Embedding, '
                            'make sure to pass it an "input_length" '
                            'argument. Otherwise, make sure '
                            'the first layer has '
                            'an "input_shape" or "batch_input_shape" '
                            'argument, including the time axis.')

        leading = (input_shape[0],)
        child_input_shape = leading + input_shape[self.first_n:]
        if not self.layer.built:
            self.layer.build(child_input_shape)
            self.layer.built = True
        super(NTimeDistributed, self).build()
onto_attention.py 文件源码 项目:onto-lstm 作者: pdasigi 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Build the underlying LSTM and, optionally, the attention weights.

        ``input_shape[4] - 1`` is used as the LSTM input size (the last
        feature, a sense prior parameter, is excluded). The parent LSTM is
        built on ``input_shape[:2] + (input_dim,)``. Any ``initial_weights``
        are applied only after every weight (LSTM and attention) exists.
        """
        self.input_spec = [InputSpec(shape=input_shape)]
        input_dim = input_shape[4] - 1  # ignore sense prior parameter
        self.input_dim = input_dim
        # Stash the onto-lstm weights so LSTM's build method (which now sees
        # initial_weights == None) won't apply or delete them.
        initial_ontolstm_weights = self.initial_weights
        self.initial_weights = None
        # tuple() so this also works when the shape arrives as a list.
        lstm_input_shape = tuple(input_shape[:2]) + (input_dim,)  # removing senses and hyps
        # Now calling LSTM's build to initialize the LSTM weights
        super(OntoAttentionLSTM, self).build(lstm_input_shape)
        # LSTM's build changes the input spec to the reduced shape; restore it.
        self.input_spec = [InputSpec(shape=input_shape)]

        if self.use_attention:
            # Following are the attention parameters
            self.input_hyp_projector = self.inner_init((input_dim, self.output_dim),
                name='{}_input_hyp_projector'.format(self.name)) # Projection operator for synsets
            self.context_hyp_projector = self.inner_init((self.output_dim, self.output_dim),
                name='{}_context_hyp_projector'.format(self.name)) # Projection operator for hidden state (context)
            self.hyp_projector2 = self.inner_init((self.output_dim, self.output_dim),
                name='{}_hyp_projector2'.format(self.name)) # Projection operator for hidden state (context)
            self.hyp_scorer = self.init((self.output_dim,), name='{}_hyp_scorer'.format(self.name))

            # LSTM's build method would have initialized trainable_weights. Add to it.
            self.trainable_weights.extend([self.input_hyp_projector, self.context_hyp_projector,
                                           self.hyp_projector2, self.hyp_scorer])

        if initial_ontolstm_weights is not None:
            self.set_weights(initial_ontolstm_weights)
onto_attention.py 文件源码 项目:onto-lstm 作者: pdasigi 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def __init__(self, num_senses, num_hyps, use_attention=False, return_attention=False, **kwargs):
        """Set up the NSE with an OntoAttentionLSTM reader.

        ``output_dim`` must be supplied as a keyword argument. It is removed
        from ``kwargs`` and passed positionally to the base constructor.
        Inputs are declared with ``ndim=5``.

        Raises:
            TypeError: if ``output_dim`` is not given.
        """
        # Explicit check instead of ``assert``, so the requirement is still
        # enforced under ``python -O``. There the assert would vanish and
        # pop() would raise a bare KeyError instead.
        if "output_dim" not in kwargs:
            raise TypeError("OntoAttentionNSE requires an 'output_dim' "
                            "keyword argument")
        output_dim = kwargs.pop("output_dim")
        super(OntoAttentionNSE, self).__init__(output_dim, **kwargs)
        self.input_spec = [InputSpec(ndim=5)]
        # TODO: Define an attention output method that rebuilds the reader.
        self.return_attention = return_attention
        self.reader = OntoAttentionLSTM(self.output_dim, num_senses, num_hyps, use_attention=use_attention,
                                        consume_less='gpu', return_attention=False)
nse.py 文件源码 项目:onto-lstm 作者: pdasigi 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def __init__(self, output_dim, input_length=None, composer_activation='linear',
                 return_mode='last_output', weights=None, **kwargs):
        '''
        Arguments:
        output_dim (int)
        input_length (int)
        composer_activation (str): activation used in the MLP
        return_mode (str): One of last_output, all_outputs, output_and_memory
            This is analogous to the return_sequences flag in Keras' Recurrent.
            last_output returns only the last h_t
            all_outputs returns the whole sequence of h_ts
            output_and_memory returns the last output and the last memory concatenated
                (needed if this layer is followed by a MMA-NSE)
        weights (list): Initial weights

        Raises:
        Exception: if return_mode is not one of the three values above.
        '''
        # Validate first so a bad configuration fails before the base layer
        # and the reader/writer/composer sublayers are constructed.
        if return_mode not in ["last_output", "all_outputs", "output_and_memory"]:
            raise Exception("Unrecognized return mode: %s" % (return_mode))
        self.output_dim = output_dim
        self.input_dim = output_dim  # Equation 2 in the paper makes this assumption.
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=3)]
        self.input_length = input_length
        self.composer_activation = composer_activation
        super(NSE, self).__init__(**kwargs)
        self.reader = LSTM(self.output_dim, dropout_W=0.0, dropout_U=0.0, consume_less="gpu",
                           name="{}_reader".format(self.name))
        # TODO: Let the writer use parameter dropout and any consume_less mode.
        # Setting dropout to 0 here to eliminate the need for constants.
        # Setting consume_less to gpu to eliminate need for preprocessing
        self.writer = LSTM(self.output_dim, dropout_W=0.0, dropout_U=0.0, consume_less="gpu",
                           name="{}_writer".format(self.name))
        self.composer = Dense(self.output_dim * 2, activation=self.composer_activation,
                              name="{}_composer".format(self.name))
        self.return_mode = return_mode
layers.py 文件源码 项目:recurrent-attention-for-QA-SQUAD-based-on-keras 作者: wentaozhu 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Build the wrapped layer and the attention parameters.

        With ``d = self.layer.output_dim`` and
        ``a = self.attention_vec._keras_shape[1]`` this creates:

        * ``U_a`` (d, d) and ``b_a`` (d,)
        * ``U_m`` (a, d) and ``b_m`` (d,)
        * ``U_s`` (d, 1) and ``b_s`` (1,) when ``single_attention_param`` is
          set, otherwise ``U_s`` (d, d) and ``b_s`` (d,).

        Raises:
            Exception: if ``self.attention_vec`` has no ``_keras_shape``.
        """
        assert len(input_shape) >= 3
        self.input_spec = [InputSpec(shape=input_shape)]

        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True

        super(AttentionLSTMWrapper, self).build()

        if hasattr(self.attention_vec, '_keras_shape'):
            attention_dim = self.attention_vec._keras_shape[1]
        else:
            raise Exception('Layer could not be built: No information about expected input shape.')

        self.U_a = self.layer.inner_init((self.layer.output_dim, self.layer.output_dim), name='{}_U_a'.format(self.name))
        self.b_a = K.zeros((self.layer.output_dim,), name='{}_b_a'.format(self.name))

        self.U_m = self.layer.inner_init((attention_dim, self.layer.output_dim), name='{}_U_m'.format(self.name))
        self.b_m = K.zeros((self.layer.output_dim,), name='{}_b_m'.format(self.name))

        if self.single_attention_param:
            self.U_s = self.layer.inner_init((self.layer.output_dim, 1), name='{}_U_s'.format(self.name))
            self.b_s = K.zeros((1,), name='{}_b_s'.format(self.name))
        else:
            self.U_s = self.layer.inner_init((self.layer.output_dim, self.layer.output_dim), name='{}_U_s'.format(self.name))
            self.b_s = K.zeros((self.layer.output_dim,), name='{}_b_s'.format(self.name))

        # The wrapped layer is presumably run from inside this wrapper rather
        # than as its own graph node. Its weights are then trained only if
        # they appear here, so assigning just the attention parameters would
        # silently freeze it. A fresh list avoids mutating the inner layer's.
        # NOTE(review): this lengthens get_weights(); weight files saved with
        # the attention-only list will no longer load positionally.
        inner_weights = list(getattr(self.layer, 'trainable_weights', []))
        self.trainable_weights = inner_weights + [self.U_a, self.U_m, self.U_s,
                                                  self.b_a, self.b_m, self.b_s]
attentionlayer.py 文件源码 项目:recurrent-attention-for-QA-SQUAD-based-on-keras 作者: wentaozhu 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def __init__(self, h, output_dim,
                 init='glorot_uniform', **kwargs):
        """Store the attention configuration and declare a 3-D input spec.

        Arguments:
            h: stored unchanged on ``self.h``.
            output_dim: stored unchanged on ``self.output_dim``.
            init: initializer identifier, resolved via ``initializations.get``.
            **kwargs: forwarded to the base layer constructor.
        """
        self.h, self.output_dim = h, output_dim
        self.init = initializations.get(init)
        # Regularizers and dropout are intentionally not supported here.
        super(AttenLayer, self).__init__(**kwargs)
        # Assigned after the base constructor. The original author notes this
        # seems necessary to accept (samples, timesteps, features) inputs.
        self.input_spec = [InputSpec(ndim=3)]
QnA.py 文件源码 项目:recurrent-attention-for-QA-SQUAD-based-on-keras 作者: wentaozhu 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Build the wrapped layer and the attention parameters.

        With ``d = self.layer.output_dim`` and
        ``a = self.attention_vec._keras_shape[1]`` this creates:

        * ``U_a`` (d, d) and ``b_a`` (d,)
        * ``U_m`` (a, d) and ``b_m`` (d,)
        * ``U_s`` (d, 1) and ``b_s`` (1,) when ``single_attention_param`` is
          set, otherwise ``U_s`` (d, d) and ``b_s`` (d,).

        Raises:
            Exception: if ``self.attention_vec`` has no ``_keras_shape``.
        """
        assert len(input_shape) >= 3
        self.input_spec = [InputSpec(shape=input_shape)]

        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True

        super(AttentionLSTMWrapper, self).build()

        if hasattr(self.attention_vec, '_keras_shape'):
            attention_dim = self.attention_vec._keras_shape[1]
        else:
            raise Exception('Layer could not be built: No information about expected input shape.')

        self.U_a = self.layer.inner_init((self.layer.output_dim, self.layer.output_dim), name='{}_U_a'.format(self.name))
        self.b_a = K.zeros((self.layer.output_dim,), name='{}_b_a'.format(self.name))

        self.U_m = self.layer.inner_init((attention_dim, self.layer.output_dim), name='{}_U_m'.format(self.name))
        self.b_m = K.zeros((self.layer.output_dim,), name='{}_b_m'.format(self.name))

        if self.single_attention_param:
            self.U_s = self.layer.inner_init((self.layer.output_dim, 1), name='{}_U_s'.format(self.name))
            self.b_s = K.zeros((1,), name='{}_b_s'.format(self.name))
        else:
            self.U_s = self.layer.inner_init((self.layer.output_dim, self.layer.output_dim), name='{}_U_s'.format(self.name))
            self.b_s = K.zeros((self.layer.output_dim,), name='{}_b_s'.format(self.name))

        # The wrapped layer is presumably run from inside this wrapper rather
        # than as its own graph node. Its weights are then trained only if
        # they appear here, so assigning just the attention parameters would
        # silently freeze it. A fresh list avoids mutating the inner layer's.
        # NOTE(review): this lengthens get_weights(); weight files saved with
        # the attention-only list will no longer load positionally.
        inner_weights = list(getattr(self.layer, 'trainable_weights', []))
        self.trainable_weights = inner_weights + [self.U_a, self.U_m, self.U_s,
                                                  self.b_a, self.b_m, self.b_s]
layer_norm_layers.py 文件源码 项目:nn_playground 作者: DingKe 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create the ``gamma`` and ``beta`` weights.

        Both weights have the same rank as the input. Their size is 1 on
        every axis except those in ``self.axis``, which keep the input's
        size. ``self.axis`` may be a single integer or an iterable of
        integers; negative values index from the end.
        """
        import numbers

        self.input_spec = [InputSpec(shape=input_shape)]
        # Accept a bare integer axis as well as a list/tuple of axes;
        # iterating over an int would otherwise raise TypeError.
        if isinstance(self.axis, numbers.Integral):
            axes = (self.axis,)
        else:
            axes = self.axis
        shape = [1 for _ in input_shape]
        for i in axes:
            shape[i] = input_shape[i]
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='gamma')
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='beta')
        self.built = True
weight_norm_layers.py 文件源码 项目:nn_playground 作者: DingKe 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create the kernel, the weight-norm scale ``g`` and the optional bias.

        The kernel has shape ``self.kernel_size + (input_dim, self.filters)``,
        where ``input_dim`` is the channel dimension (axis 1 for
        ``channels_first``, otherwise the last axis). ``g`` has shape
        ``(1,) * (self.rank + 1) + (self.filters,)``, which has the same rank
        as the kernel for any convolution rank and equals
        ``(1, 1, 1, filters)`` for rank 2.

        Raises:
            ValueError: if the channel dimension is ``None``.
        """
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        # Derive the leading singleton dims from the rank instead of
        # hard-coding three, which only fit 2D kernels.
        # NOTE(review): assumes call() broadcasts g against the kernel.
        g_shape = (1,) * (self.rank + 1) + (self.filters,)
        self.g = self.add_weight(shape=g_shape,
                                 initializer='one',
                                 name='g')
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})

        self.built = True
FixedBatchNormalization.py 文件源码 项目:AerialCrackDetection_Keras 作者: TTMRonald 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create the frozen batch-normalization parameters.

        ``gamma``, ``beta``, ``running_mean`` and ``running_std`` all have
        shape ``(input_shape[self.axis],)`` and are created with
        ``trainable=False``. ``initial_weights``, if present, are applied
        afterwards and then deleted.
        """
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (input_shape[self.axis],)

        # ``shape`` is passed by keyword. Later Keras 2 releases declare
        # ``add_weight(name, shape, ...)``, so a positional shape collides with
        # ``name=`` ("got multiple values for argument 'name'").
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='{}_gamma'.format(self.name),
                                     trainable=False)
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='{}_beta'.format(self.name),
                                    trainable=False)
        self.running_mean = self.add_weight(shape=shape, initializer='zero',
                                            name='{}_running_mean'.format(self.name),
                                            trainable=False)
        self.running_std = self.add_weight(shape=shape, initializer='one',
                                           name='{}_running_std'.format(self.name),
                                           trainable=False)

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

        self.built = True
FixedBatchNormalization.py 文件源码 项目:keras-frcnn 作者: yhenon 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create the frozen batch-normalization parameters.

        ``gamma``, ``beta``, ``running_mean`` and ``running_std`` all have
        shape ``(input_shape[self.axis],)`` and are created with
        ``trainable=False``. ``initial_weights``, if present, are applied
        afterwards and then deleted.
        """
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (input_shape[self.axis],)

        # ``shape`` is passed by keyword. Later Keras 2 releases declare
        # ``add_weight(name, shape, ...)``, so a positional shape collides with
        # ``name=`` ("got multiple values for argument 'name'").
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='{}_gamma'.format(self.name),
                                     trainable=False)
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='{}_beta'.format(self.name),
                                    trainable=False)
        self.running_mean = self.add_weight(shape=shape, initializer='zero',
                                            name='{}_running_mean'.format(self.name),
                                            trainable=False)
        self.running_std = self.add_weight(shape=shape, initializer='one',
                                           name='{}_running_std'.format(self.name),
                                           trainable=False)

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

        self.built = True
mobilenet.py 文件源码 项目:deep-learning-models 作者: fchollet 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create the depthwise kernel and optional bias.

        The kernel has shape
        ``(kernel_size[0], kernel_size[1], input_dim, depth_multiplier)``,
        where ``input_dim`` is the channel dimension (axis 1 for
        ``channels_first``, otherwise axis 3). The bias, when enabled, has
        ``input_dim * depth_multiplier`` entries.

        Raises:
            ValueError: if the input has fewer than 4 dimensions or the
                channel dimension is ``None``.
        """
        if len(input_shape) < 4:
            # Build a single message string. Passing the shape as a second
            # argument made str(error) render as a tuple.
            raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
                             'Received input shape: ' + str(input_shape))
        channel_axis = 1 if self.data_format == 'channels_first' else 3
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`DepthwiseConv2D` '
                             'should be defined. Found `None`.')
        input_dim = int(input_shape[channel_axis])
        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  input_dim,
                                  self.depth_multiplier)

        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(input_dim * self.depth_multiplier,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
        self.built = True
renormalization.py 文件源码 项目:DeepTrade_keras 作者: happynoom 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create the batch-renormalization parameters and state variables.

        ``gamma`` and ``beta`` (trainable), plus ``running_mean`` and
        ``running_std`` (non-trainable), all have shape
        ``(input_shape[self.axis],)``. ``r_max`` (ones), ``d_max`` and ``t``
        (zeros) are one-element backend variables. ``initial_weights``, if
        present, are applied afterwards and then deleted.
        """
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (input_shape[self.axis],)

        # ``shape`` is passed by keyword. Later Keras 2 releases declare
        # ``add_weight(name, shape, ...)``, so a positional shape collides with
        # ``name=`` ("got multiple values for argument 'name'").
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='{}_gamma'.format(self.name))
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='{}_beta'.format(self.name))
        self.running_mean = self.add_weight(shape=shape, initializer='zero',
                                            name='{}_running_mean'.format(self.name),
                                            trainable=False)
        # Note: running_std actually holds the running variance, not the running std.
        self.running_std = self.add_weight(shape=shape, initializer='one',
                                           name='{}_running_std'.format(self.name),
                                           trainable=False)

        self.r_max = K.variable(np.ones((1,)), name='{}_r_max'.format(self.name))

        self.d_max = K.variable(np.zeros((1,)), name='{}_d_max'.format(self.name))

        self.t = K.variable(np.zeros((1,)), name='{}_t'.format(self.name))

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True
local.py 文件源码 项目:keras 作者: GeekLiB 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def __init__(self, nb_filter, filter_length,
                 init='glorot_uniform', activation=None, weights=None,
                 border_mode='valid', subsample_length=1,
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, input_dim=None, input_length=None, **kwargs):
        """Configure a LocallyConnected1D layer.

        Only ``border_mode='valid'`` is accepted. When ``input_dim`` is
        truthy, ``kwargs['input_shape']`` is set to
        ``(input_length, input_dim)`` before the base constructor runs.
        Inputs are declared with ``ndim=3``.
        """
        if border_mode != 'valid':
            raise Exception('Invalid border mode for LocallyConnected1D '
                            '(only "valid" is supported):', border_mode)

        # Plain configuration values.
        self.nb_filter, self.filter_length = nb_filter, filter_length
        self.border_mode, self.subsample_length = border_mode, subsample_length
        self.bias = bias
        self.initial_weights = weights
        self.input_dim, self.input_length = input_dim, input_length

        # Looked-up helpers, resolved in a fixed order.
        self.init = initializations.get(init, dim_ordering='th')
        self.activation = activations.get(activation)
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.input_spec = [InputSpec(ndim=3)]
        if input_dim:
            kwargs['input_shape'] = (input_length, input_dim)
        super(LocallyConnected1D, self).__init__(**kwargs)
local.py 文件源码 项目:keras 作者: GeekLiB 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def __init__(self, nb_filter, nb_row, nb_col,
                 init='glorot_uniform', activation=None, weights=None,
                 border_mode='valid', subsample=(1, 1),
                 dim_ordering='default',
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        """Configure a LocallyConnected2D layer.

        ``dim_ordering='default'`` is replaced by ``K.image_dim_ordering()``
        and must end up as ``'tf'`` or ``'th'``. Only ``border_mode='valid'``
        is accepted. Inputs are declared with ``ndim=4``.
        """
        dim_ordering = K.image_dim_ordering() if dim_ordering == 'default' else dim_ordering
        if border_mode != 'valid':
            raise Exception('Invalid border mode for LocallyConnected2D '
                            '(only "valid" is supported):', border_mode)

        self.nb_filter, self.nb_row, self.nb_col = nb_filter, nb_row, nb_col
        self.init = initializations.get(init, dim_ordering=dim_ordering)
        self.activation = activations.get(activation)

        self.border_mode = border_mode
        self.subsample = tuple(subsample)
        assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
        self.dim_ordering = dim_ordering

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias, self.initial_weights = bias, weights
        self.input_spec = [InputSpec(ndim=4)]
        super(LocallyConnected2D, self).__init__(**kwargs)
convolutional.py 文件源码 项目:keras-contrib 作者: farizrahman4u 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def __init__(self, filters, kernel_size,
                 kernel_initializer='glorot_uniform', activation=None, weights=None,
                 padding='valid', strides=(1, 1), data_format=None,
                 kernel_regularizer=None, bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None, bias_constraint=None,
                 use_bias=True, **kwargs):
        """Configure a CosineConvolution2D layer.

        Arguments:
            kernel_size: an int or a pair ``(nb_row, nb_col)``; an int ``k``
                is expanded to ``(k, k)``.
            strides: an int or a pair; an int ``s`` is expanded to ``(s, s)``.
            padding: one of ``'valid'``, ``'same'`` or ``'full'``.
            data_format: defaults to ``K.image_data_format()`` and is then
                passed through ``normalize_data_format``.
            weights: stored on ``self.initial_weights``.
            The remaining initializer/regularizer/constraint/activation
            identifiers are resolved through their Keras ``get`` helpers.

        Raises:
            ValueError: for an unsupported ``padding``.
        """
        if data_format is None:
            data_format = K.image_data_format()
        if padding not in {'valid', 'same', 'full'}:
            # One message string; a second positional arg renders as a tuple.
            raise ValueError('Invalid border mode for CosineConvolution2D: ' + str(padding))
        # Accept the scalar shorthand. Previously an int failed on the tuple
        # unpacking of kernel_size / tuple(strides) below.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(strides, int):
            strides = (strides, strides)
        self.filters = filters
        self.kernel_size = kernel_size
        self.nb_row, self.nb_col = self.kernel_size
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.activation = activations.get(activation)
        self.padding = padding
        self.strides = tuple(strides)
        self.data_format = normalize_data_format(data_format)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.use_bias = use_bias
        self.input_spec = [InputSpec(ndim=4)]
        self.initial_weights = weights
        super(CosineConvolution2D, self).__init__(**kwargs)
normalization.py 文件源码 项目:keras-contrib 作者: farizrahman4u 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Validate ``self.axis`` and create the optional scale/offset weights.

        With ``self.axis`` unset, the parameters have shape ``(1,)``.
        Otherwise they have ``input_shape[self.axis]`` entries. ``gamma`` is
        created only when ``self.scale`` is set and ``beta`` only when
        ``self.center`` is set; each is ``None`` otherwise.

        Raises:
            ValueError: if ``self.axis`` is 0, or is set for a 2-D input.
        """
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')
        if ndim == 2 and self.axis is not None:
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)
        shape = (1,) if self.axis is None else (input_shape[self.axis],)

        self.gamma = None
        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        self.beta = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        self.built = True
advanced_activations.py 文件源码 项目:keras-contrib 作者: farizrahman4u 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def build(self, input_shape):
        """Create the ``alpha`` and ``beta`` parameters and the input spec.

        Parameters have shape ``input_shape[1:]``, with size 1 on every axis
        listed in ``self.shared_axes``. ``self.param_broadcast`` records which
        of those axes are broadcast. The input spec pins every non-shared
        axis (only when ``shared_axes`` is non-empty).
        """
        param_shape = list(input_shape[1:])
        self.param_broadcast = [False] * len(param_shape)
        if self.shared_axes is not None:
            for i in self.shared_axes:
                param_shape[i - 1] = 1
                self.param_broadcast[i - 1] = True

        param_shape = tuple(param_shape)
        # ``shape`` is passed by keyword. Later Keras 2 releases declare
        # ``add_weight(name, shape, ...)``, so a positional shape collides with
        # ``name=`` ("got multiple values for argument 'name'").
        # Initialised as ones to emulate the default ELU
        self.alpha = self.add_weight(shape=param_shape,
                                     name='alpha',
                                     initializer=self.alpha_initializer,
                                     regularizer=self.alpha_regularizer,
                                     constraint=self.alpha_constraint)
        self.beta = self.add_weight(shape=param_shape,
                                    name='beta',
                                    initializer=self.beta_initializer,
                                    regularizer=self.beta_regularizer,
                                    constraint=self.beta_constraint)

        # Set input spec
        axes = {}
        if self.shared_axes:
            for i in range(1, len(input_shape)):
                if i not in self.shared_axes:
                    axes[i] = input_shape[i]
        self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
        self.built = True


问题


面经


文章

微信
公众号

扫码关注公众号