def __init__(self, incoming, pattern, **kwargs):
    """Validate and store a dimension-shuffle ``pattern``.

    Parameters
    ----------
    incoming
        Forwarded unchanged, together with ``**kwargs``, to the base-class
        ``__init__`` (presumably the input layer or its shape -- confirm
        against the base class).
    pattern : iterable of int or 'x'
        Each element is either an ``int`` naming an input dimension or the
        string ``'x'`` marking a broadcast axis.  Every integer may appear
        at most once.  A one-shot iterator (e.g. a generator) is
        materialized into a tuple so it can be read again after validation;
        lists and tuples are stored as given.
    **kwargs
        Passed through to the base-class ``__init__``.

    Raises
    ------
    ValueError
        If an integer appears more than once in ``pattern``, or if an
        element is neither an ``int`` (``bool`` is rejected) nor ``'x'``.
        Any error raised by ``self.get_output_shape_for(self.input_shape)``
        also propagates.
    """
    super(DimshuffleLayer, self).__init__(incoming, **kwargs)

    # A one-shot iterator would be exhausted by the validation loop below,
    # leaving ``self.pattern`` empty; materialize it so it can be re-read.
    # Re-iterable containers (list, tuple, ...) are kept unchanged.
    if iter(pattern) is pattern:
        pattern = tuple(pattern)

    # Sanity check the pattern
    used_dims = set()
    for p in pattern:
        # bool is a subclass of int, so without the extra check True/False
        # would silently be accepted as dimensions 1/0.
        if isinstance(p, int) and not isinstance(p, bool):
            # Dimension p
            if p in used_dims:
                raise ValueError("pattern contains dimension {0} more "
                                 "than once".format(p))
            used_dims.add(p)
        elif p == 'x':
            # Broadcast
            pass
        else:
            raise ValueError("pattern should only contain dimension "
                             "indices or 'x', not {0}".format(p))

    self.pattern = pattern
    # try computing the output shape once as a sanity check
    self.get_output_shape_for(self.input_shape)
评论列表
文章目录