utils.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:mxnet_tk1 作者: starimpact 项目源码 文件源码
def replace_conv_layer(layer_name, old_model, sym_handle, arg_handle,
                       data_shape=(1, 3, 224, 224)):
  """Rebuild ``old_model`` with the node named ``layer_name`` replaced.

  The old model's symbol graph is serialized to JSON, its nodes are passed
  through ``topsort`` and then rebuilt one by one: the input node becomes a
  fresh ``data`` variable, the node called ``layer_name`` is rebuilt by
  ``sym_handle`` and every other operator node by ``sym_factory``. Each
  operator is fed only the symbol of its first data input, so this assumes
  a single-input (chain-like) network -- NOTE(review): confirm before using
  it on models with multi-input ops such as concat or elementwise sum.

  Args:
    layer_name: Name of the graph node to replace. If falsy, ``arg_handle``
      is not called and every operator node is rebuilt by ``sym_factory``.
    old_model: Source model; its ``symbol``, ``arg_params``, ``aux_params``,
      ``ctx``, ``epoch_size``, ``initializer``, ``numpy_batch_size`` and
      ``begin_epoch`` attributes are read (presumably an
      ``mx.model.FeedForward``).
    sym_handle: Callable ``(data, node) -> symbol`` that builds the
      replacement layer from its input symbol and the node's JSON dict.
    arg_handle: Callable ``(arg_shape_dic, arg_params)`` given a mapping of
      argument name to the shape inferred for the rebuilt graph and a deep
      copy of ``old_model.arg_params``. Its return value is ignored, so it
      must update ``arg_params`` in place.
    data_shape: Shape bound to the ``data`` input when inferring argument
      shapes for ``arg_handle``. Defaults to the previously hard-coded
      ``(1, 3, 224, 224)``.

  Returns:
    A new ``mx.model.FeedForward`` built on the rebuilt symbol, with
    ``num_epoch=1``, ``optimizer='sgd'`` and ``allow_extra_params=True``.
    ``aux_params`` is shared with ``old_model`` (not copied).

  Raises:
    KeyError: The symbol for an operator's data input has not been built
      yet.
    IndexError: An operator has no data input.
    ValueError: The graph contains no node that could be rebuilt.
    Exception: ``arg_handle`` raised; the original error is included in the
      message.
  """
  conf = json.loads(old_model.symbol.tojson())
  # NOTE(review): input indices in the JSON refer to the *original* node
  # order; indexing the topsorted list with them assumes ``topsort`` keeps
  # (or consistently remaps) that order -- confirm against ``topsort``.
  nodes = topsort(conf['nodes'])
  sym_dict = {}
  res_sym = None
  for node in nodes:
    sym = None
    if is_input(node):
      sym = mx.symbol.Variable(name='data')
    elif node['op'] != 'null':
      input_nodes = [nodes[int(j[0])] for j in node['inputs']]
      # Skip inputs named after this node -- presumably its own parameter
      # variables such as '<name>_weight' / '<name>_bias'.
      datas = [input_node['name'] for input_node in input_nodes
               if not input_node['name'].startswith(node['name'])]
      if not datas:
        # Fail with a clear message instead of indexing an empty list.
        raise IndexError('node %s has no data input' % node['name'])
      try:
        data = sym_dict[datas[0]]
      except KeyError:
        print('can not find symbol %s' % datas[0])
        # Bare ``raise`` keeps the original traceback (``raise e`` loses it
        # in Python 2).
        raise
      if node['name'] == layer_name:
        sym = sym_handle(data, node)
      else:
        sym = sym_factory(node, data)
    # The intent is "a symbol was built", not the symbol's truthiness.
    if sym is not None:
      sym_dict[node['name']] = sym
      # The last symbol built is taken as the network output.
      res_sym = sym

  if res_sym is None:
    raise ValueError('model symbol has no data input or operator nodes')

  arg_params = copy.deepcopy(old_model.arg_params)
  if layer_name:
    arg_shapes, _, _ = res_sym.infer_shape(data=data_shape)
    arg_shape_dic = dict(zip(res_sym.list_arguments(), arg_shapes))
    try:
      arg_handle(arg_shape_dic, arg_params)
    except Exception as e:
      # Keep the exception type callers may catch, but don't discard the
      # underlying cause.
      raise Exception('Exception in arg_handle: %r' % (e,))

  return mx.model.FeedForward(
      symbol=res_sym,
      ctx=old_model.ctx,
      num_epoch=1,
      epoch_size=old_model.epoch_size,
      optimizer='sgd',
      initializer=old_model.initializer,
      numpy_batch_size=old_model.numpy_batch_size,
      arg_params=arg_params,
      aux_params=old_model.aux_params,
      allow_extra_params=True,
      begin_epoch=old_model.begin_epoch)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号