def __init__(self, inner_layer_arg, **kwargs):
    """Initialise the layer from one of three kinds of ``inner_layer_arg``.

    1. An integer ``conv_width``: inner layers are ``layers.Dense`` instances
       created with ``conv_width`` as their first argument. Keyword arguments
       that ``layers.Dense.__init__`` accepts are split off via
       ``filter_func_args`` and passed to every Dense created; the remaining
       kwargs go to the parent constructor.
    2. An unbuilt keras layer instance: ``conv_width`` is read from its output
       shape, and fresh inner layers are re-created from its config on demand.
    3. A callable taking no arguments that returns an unbuilt keras layer: it
       is called once here to determine ``conv_width`` and is then stored as
       the inner-layer factory.

    Args:
        inner_layer_arg: an int, an unbuilt keras layer, or a callable, as
            described above.
        **kwargs: forwarded to the parent layer constructor (minus the
            Dense-specific arguments in case 1).

    Raises:
        ValueError: if ``inner_layer_arg`` is none of the supported kinds, if
            the given or returned layer is already built, or if the callable
            does not return a keras layer.
    """
    import numbers  # local import: works on both Python 2 and Python 3

    # Case 1: inner_layer_arg is conv_width.
    # numbers.Integral covers Python 2 int/long, Python 3 int (which has no
    # separate `long`) and numpy integer scalars. bool is excluded because
    # True/False would otherwise silently become a width of 1/0.
    if (isinstance(inner_layer_arg, numbers.Integral)
            and not isinstance(inner_layer_arg, bool)):
        self.conv_width = inner_layer_arg
        # NOTE(review): 'name' is overruled -- presumably so that the outer
        # layer's name stays in kwargs for the parent constructor instead of
        # being given to every Dense; confirm against filter_func_args.
        dense_layer_kwargs, kwargs = filter_func_args(
            layers.Dense.__init__, kwargs, overrule_args=['name'])
        self.create_inner_layer_fn = lambda: layers.Dense(
            self.conv_width, **dense_layer_kwargs)
    # Case 2: an initialised (but unbuilt) keras layer. This is checked before
    # the generic callable case because keras layers define __call__.
    elif isinstance(inner_layer_arg, layers.Layer):
        if inner_layer_arg.built:
            raise ValueError(
                'When initialising with a keras layer, it cannot be built.')
        # The output shape for a (None, None) input must unpack into exactly
        # two entries; the second one is taken as the width.
        _, self.conv_width = inner_layer_arg.get_output_shape_for((None, None))
        # layer_from_config will mutate the config dict, therefore fetch a
        # fresh config on every call.
        self.create_inner_layer_fn = lambda: layer_from_config(dict(
            class_name=inner_layer_arg.__class__.__name__,
            config=inner_layer_arg.get_config()))
    # Case 3: a function that returns an initialised keras layer.
    elif callable(inner_layer_arg):
        example_instance = inner_layer_arg()
        if not isinstance(example_instance, layers.Layer):
            raise ValueError(
                'When initialising with a function, the function has to '
                'return a keras layer')
        if example_instance.built:
            raise ValueError(
                'When initialising with a function, the keras layer it '
                'returns cannot be built.')
        _, self.conv_width = example_instance.get_output_shape_for((None, None))
        self.create_inner_layer_fn = inner_layer_arg
    else:
        raise ValueError(
            'NeuralGraphHidden has to be initialised with 1). int conv_width, '
            '2). a keras layer instance, or 3). a function returning a keras '
            'layer instance.')
    super(NeuralGraphHidden, self).__init__(**kwargs)
layers.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录