def assign_attr_value(attr, val):
    """Assign ``val`` to the ``AttrValue`` protobuf message ``attr`` by Python type.

    The target field is chosen from the runtime type of ``val``:

    * ``bool``                  -> ``attr.b``
    * integer (``integer_types``) -> ``attr.i``
    * ``float``                 -> ``attr.f``
    * text / bytes              -> ``attr.s`` (text is UTF-8 encoded first)
    * ``TensorShape``           -> merged into ``attr.shape``
    * non-empty list / tuple    -> ``attr.list.i``, ``attr.list.f`` or
      ``attr.list.shape``, chosen by the type of the first element

    Args:
        attr: The ``AttrValue`` message to populate; it is modified in place.
        val: The Python value to store.

    Raises:
        NotImplementedError: If ``val`` (or, for a sequence, its first
            element) has a type with no corresponding ``AttrValue`` field.
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid an import cycle with the generated IR module -- TODO confirm.
    from mmdnn.conversion.common.IR.graph_pb2 import TensorShape

    # bool must be tested before the integer check: bool is a subclass of int.
    if isinstance(val, bool):
        attr.b = val
    elif isinstance(val, integer_types):
        attr.i = val
    elif isinstance(val, float):
        attr.f = val
    elif isinstance(val, (binary_type, text_type)):
        # Only text needs encoding. Encode explicitly as UTF-8 so the result
        # does not depend on the interpreter's default codec, and never call
        # .encode() on bytes (on Python 2 that implicitly ASCII-decodes).
        if isinstance(val, text_type):
            val = val.encode('utf-8')
        attr.s = val
    elif isinstance(val, TensorShape):
        attr.shape.MergeFromString(val.SerializeToString())
    elif isinstance(val, (list, tuple)):
        # An empty sequence leaves ``attr`` untouched (no field is set).
        if not val:
            return
        first = val[0]
        # The element type is inferred from the first element only. A list of
        # bools lands in ``list.i`` because bool is an int subclass.
        if isinstance(first, integer_types):
            attr.list.i.extend(val)
        elif isinstance(first, float):
            attr.list.f.extend(val)
        elif isinstance(first, TensorShape):
            attr.list.shape.extend(val)
        else:
            raise NotImplementedError(
                'AttrValue cannot be of list[{}].'.format(type(first)))
    else:
        raise NotImplementedError('AttrValue cannot be of %s' % type(val))
评论列表
文章目录