def __init__(self, synset_path, network_prefix, params_url=None, symbol_url=None, synset_url=None, context=mx.cpu(), label_names=['prob_label'], input_shapes=[('data', (1,3,224,224))]):
    """Optionally download a network and its synset file, then load them.

    Each URL that is not None is downloaded before anything is loaded, so
    a fresh download overwrites any existing file at the same location.

    Args:
        synset_path: Path of the synset file. Written when ``synset_url``
            is given, then read line by line into ``self.synsets`` with
            trailing whitespace stripped from every line.
        network_prefix: Checkpoint prefix passed to
            ``mx.model.load_checkpoint`` with epoch 0. Downloads are saved
            as ``<prefix>-0000.params`` and ``<prefix>-symbol.json``.
        params_url: Optional URL of the parameters file.
        symbol_url: Optional URL of the symbol JSON file.
        synset_url: Optional URL of the synset file.
        context: Passed to ``mx.mod.Module`` as ``context``.
        label_names: Passed to ``mx.mod.Module`` as ``label_names``.
        input_shapes: Passed to ``Module.bind`` as ``data_shapes``.

    Note:
        The ``context``, ``label_names`` and ``input_shapes`` defaults are
        created once, when the method is defined, and shared by all calls.
        This method only passes them through and never modifies them.
    """
    # Imported locally so the method does not rely on the module's import
    # block providing these two standard-library helpers.
    import contextlib
    import shutil

    # (name used in the progress message, source URL, destination file)
    downloads = (
        ("params", params_url, network_prefix + "-0000.params"),
        ("symbols", symbol_url, network_prefix + "-symbol.json"),
        ("synset", synset_url, synset_path),
    )
    for label, url, destination in downloads:
        if url is None:
            continue
        # Single parenthesised argument: prints the same text whether
        # print is a statement or a function.
        print("fetching " + label + " from " + url)
        # closing() releases the connection even if opening or writing the
        # destination file fails; copyfileobj streams the payload in
        # chunks instead of holding the whole file in memory.
        with contextlib.closing(urllib2.urlopen(url)) as response:
            with open(destination, 'wb') as output:
                shutil.copyfileobj(response, output)

    # Load the class labels, one entry per line of the synset file.
    with open(synset_path, 'r') as synset_file:
        self.synsets = [line.rstrip() for line in synset_file]

    # Load the network definition and parameters saved as epoch 0.
    sym, arg_params, aux_params = mx.model.load_checkpoint(network_prefix, 0)

    # Wrap the network in an MXNet module bound for inference only, then
    # install the loaded parameters.
    self.mod = mx.mod.Module(symbol=sym, label_names=label_names, context=context)
    self.mod.bind(for_training=False, data_shapes=input_shapes)
    self.mod.set_params(arg_params, aux_params)

    # No camera is attached at construction time.
    self.camera = None
评论列表
文章目录