def replace_conv_layer(layer_name, old_model, sym_handle, arg_handle,
                       data_shape=(1, 3, 224, 224)):
    """Return a copy of ``old_model`` with the node named ``layer_name`` rebuilt.

    The model's symbol is serialised to JSON, its nodes are ordered with
    ``topsort`` and every operator node is re-created in that order:

    * the node whose name equals ``layer_name`` is built by
      ``sym_handle(data, node)``;
    * every other operator node is built by ``sym_factory(node, data)``;
    * the input node (``is_input``) becomes ``mx.symbol.Variable('data')``;
    * other ``'null'`` nodes are skipped.

    Only the FIRST input whose name does not start with the node's own name
    is wired in as ``data``. Inputs sharing the node's name prefix are
    presumably the node's own parameters (e.g. ``conv1_weight``). As a
    result, operators with several data inputs (e.g. Concat) are probably
    not reproduced faithfully.

    Args:
        layer_name: name of the node to replace. When falsy, no node is
            replaced and ``arg_handle`` is not called.
        old_model: the ``mx.model.FeedForward`` to copy. It is not modified.
        sym_handle: callable ``(data, node) -> Symbol`` building the
            replacement layer.
        arg_handle: callable ``(arg_shape_dic, arg_params)``. Its return
            value is ignored, so it must update ``arg_params`` (a deep copy
            of ``old_model.arg_params``) in place.
        data_shape: shape of the ``data`` input used for shape inference.
            It defaults to the previously hard-coded ``(1, 3, 224, 224)``.

    Returns:
        A new ``mx.model.FeedForward`` built from the rebuilt symbol and the
        copied (possibly adjusted) parameters. It reuses ``old_model``'s
        ctx, epoch_size, initializer, numpy_batch_size, aux_params and
        begin_epoch, and always uses ``num_epoch=1``, ``optimizer='sgd'``
        and ``allow_extra_params=True``.

    Raises:
        KeyError / IndexError: a node's data input was not built before it,
            or the node has no data input.
        ValueError: no symbol was built, or argument shapes could not be
            inferred for ``data_shape``.
        Exception: wraps any exception raised by ``arg_handle``.
    """
    conf = json.loads(old_model.symbol.tojson())
    nodes = topsort(conf['nodes'])
    sym_dict = {}
    res_sym = None
    for node in nodes:
        sym = None
        if is_input(node):
            sym = mx.symbol.Variable(name='data')
        elif node['op'] != 'null':
            input_nodes = [nodes[int(j[0])] for j in node['inputs']]
            # Skip inputs prefixed by this node's name: presumably its own
            # parameters, which the symbol constructors recreate.
            datas = [inp['name'] for inp in input_nodes
                     if not inp['name'].startswith(node['name'])]
            try:
                data = sym_dict[datas[0]]
            except (IndexError, KeyError):
                # Do not index datas again here: it may be empty.
                print('can not find symbol %s'
                      % (datas[0] if datas else '<no data input>'))
                raise
            if node['name'] == layer_name:
                sym = sym_handle(data, node)
            else:
                sym = sym_factory(node, data)
        # Compare with None explicitly: MXNet Symbol objects may refuse
        # bool() conversion, so a truthiness test can raise.
        if sym is not None:
            sym_dict[node['name']] = sym
            res_sym = sym  # the last symbol built is the network output

    if res_sym is None:
        raise ValueError('no symbol could be rebuilt from old_model.symbol')

    # Work on a copy so old_model's parameters are left untouched.
    arg_params = copy.deepcopy(old_model.arg_params)
    if layer_name:
        arg_shapes, _, _ = res_sym.infer_shape(data=data_shape)
        if arg_shapes is None:
            raise ValueError('could not infer argument shapes for data '
                             'shape %s' % (data_shape,))
        arg_shape_dic = dict(zip(res_sym.list_arguments(), arg_shapes))
        try:
            arg_handle(arg_shape_dic, arg_params)
        except Exception as e:
            # Keep the original error text; the bare message hid the cause.
            raise Exception('Exception in arg_handle: %s' % (e,))

    return mx.model.FeedForward(
        symbol=res_sym,
        ctx=old_model.ctx,
        num_epoch=1,
        epoch_size=old_model.epoch_size,
        optimizer='sgd',
        initializer=old_model.initializer,
        numpy_batch_size=old_model.numpy_batch_size,
        arg_params=arg_params,
        aux_params=old_model.aux_params,
        allow_extra_params=True,
        begin_epoch=old_model.begin_epoch)
# Residue scraped from the source web page (not code):
# 评论列表 (comment list)
# 文章目录 (table of contents)