layers.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:keras-neural-graph-fingerprint 作者: keiserlab 项目源码 文件源码
def __init__(self, inner_layer_arg, **kwargs):
        """Initialise the layer from one of three kinds of `inner_layer_arg`.

        Whichever form is given, two attributes are set on the instance:
        `conv_width` and `create_inner_layer_fn`, a zero-argument callable
        that produces an inner keras layer each time it is called.

        # Arguments
            inner_layer_arg: One of
                1. an integer, used as `conv_width`; the inner layer
                   factory then creates `layers.Dense(conv_width, ...)`,
                2. an unbuilt keras layer instance; the factory recreates
                   it from its class name and config on every call,
                3. a callable returning an unbuilt keras layer instance;
                   the callable itself is stored as the factory.
            **kwargs: In case 1 these are split by `filter_func_args` into
                keyword arguments for the `Dense` layer and the remainder
                (exact splitting rules live in that helper; not visible
                here). Whatever remains is forwarded to the parent
                constructor.

        # Raises
            ValueError: If `inner_layer_arg` matches none of the three forms.
            AssertionError: If a supplied (or factory-produced) layer is
                already built, or the factory does not return a keras layer.
        """
        # Imported locally so this method does not rely on the module's
        # import block. `numbers.Integral` covers `int` and `long` on
        # Python 2 and `int` on Python 3; the bare name `long` does not
        # exist on Python 3 and raised NameError for every call.
        import numbers

        # Case 1: Check if inner_layer_arg is conv_width
        if isinstance(inner_layer_arg, numbers.Integral):
            self.conv_width = inner_layer_arg
            dense_layer_kwargs, kwargs = filter_func_args(
                layers.Dense.__init__, kwargs, overrule_args=['name'])
            # conv_width is read from self when the factory is called,
            # not captured at definition time.
            self.create_inner_layer_fn = lambda: layers.Dense(
                self.conv_width, **dense_layer_kwargs)

        # Case 2: Check if an initialised keras layer is given
        elif isinstance(inner_layer_arg, layers.Layer):
            assert not inner_layer_arg.built, 'When initialising with a keras layer, it cannot be built.'
            # Second entry of the output shape for a 2D input is the width.
            _, self.conv_width = inner_layer_arg.get_output_shape_for((None, None))
            # layer_from_config will mutate the config dict, therefore the
            # config is fetched afresh inside the factory on every call.
            self.create_inner_layer_fn = lambda: layer_from_config(dict(
                class_name=inner_layer_arg.__class__.__name__,
                config=inner_layer_arg.get_config()))

        # Case 3: Check if a function is provided that returns a initialised keras layer
        elif callable(inner_layer_arg):
            # Call the factory once only to validate it and read the width;
            # this example instance is not kept.
            example_instance = inner_layer_arg()
            assert isinstance(example_instance, layers.Layer), 'When initialising with a function, the function has to return a keras layer'
            assert not example_instance.built, 'When initialising with a keras layer, it cannot be built.'
            _, self.conv_width = example_instance.get_output_shape_for((None, None))
            self.create_inner_layer_fn = inner_layer_arg

        else:
            raise ValueError('NeuralGraphHidden has to be initialised with 1). int conv_width, 2). a keras layer instance, or 3). a function returning a keras layer instance.')

        super(NeuralGraphHidden, self).__init__(**kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号