def decompose_size(size):
    """Compute the number of input and output units for a weight shape.

    Parameters
    ----------
    size
        Integer shape tuple (any sequence of ints supporting ``len`` and
        slicing).

    Returns
    -------
    tuple of int
        ``(fan_in, fan_out)`` as plain Python ints.

    Notes
    -----
    * Rank 2: ``fan_in = size[0]`` and ``fan_out = size[1]``.
    * Rank 4 or 5: ``fan_in = size[1] * rf`` and ``fan_out = size[0] * rf``,
      where ``rf = prod(size[2:])`` is the receptive-field size (presumably
      a convolution kernel laid out as ``(out, in, *spatial)`` -- confirm
      against callers).
    * Any other rank (including 0, 1 and 3): both fans fall back to
      ``int(sqrt(prod(size)))``.
    """
    ndim = len(size)
    if ndim == 2:
        fan_in, fan_out = size[0], size[1]
    elif ndim in (4, 5):
        receptive_field_size = np.prod(size[2:])
        fan_in = size[1] * receptive_field_size
        fan_out = size[0] * receptive_field_size
    else:
        # NOTE(review): 3-D shapes (e.g. 1-D conv kernels) also land here and
        # are treated as "square" rather than scaled by a receptive field --
        # kept for backward compatibility; confirm this is intended.
        fan_in = fan_out = int(np.sqrt(np.prod(size)))
    # np.prod yields NumPy scalars; normalize so every branch returns the
    # same Python int type.
    return int(fan_in), int(fan_out)
评论列表
文章目录