def Unroll(axis, num=None):
    """
    Define an _OperationalLayer that unpacks a tensor along a given axis.

    Parameters:
    ----------
    axis: int
        The axis along which the input tensor is unpacked.
    num: int
        The number of tensors to unpack from the given tensor. When left
        as None it is passed through unchanged to TensorFlow.

    Returns: _OperationalLayer
        A layer wrapping the unpack operation, constructed with the
        parameter list ``[num, axis]``.
    """
    def unroll_op(obj, X):
        # The layer is constructed below with [num, axis]; the op reads them
        # back from obj.params in that same order.
        count, unpack_axis = obj.params[0], obj.params[1]
        # tf.unpack was renamed to tf.unstack; prefer the new name when the
        # installed TensorFlow provides it and fall back to the old one.
        unstack = getattr(tf, "unstack", None)
        if unstack is None:
            unstack = tf.unpack
        # Previously the axis was hard-coded to 1, which ignored the `axis`
        # argument given to Unroll.
        return unstack(X, num=count, axis=unpack_axis)

    return _OperationalLayer(unroll_op, [num, axis])
评论列表
文章目录