def _get_module(args, margs, dargs, net=None):
    """Wrap a network symbol in an ``mx.mod.Module`` placed on the GPUs in ``args.gpus``.

    If ``net`` is None, a symbol is built from the module-level ``model_specs``
    dict. NOTE(review): ``model_specs`` is not defined in this function; it is
    assumed to be populated elsewhere in the file before this is called.
    Only ``net_type == 'rna'`` combined with ``net_name == 'a1'`` is wired up here.

    Args:
        args: Run arguments. ``args.gpus`` is a comma-separated string of
            integer GPU ids, e.g. ``"0,1"``.
        margs: Model arguments. ``margs.classes`` and ``margs.feat_stride`` are
            passed to the network builder.
        dargs: Other arguments. ``dargs.mx_workspace`` is stored as
            ``symcfg['workspace']`` before the symbol is built.
        net: An already-built symbol. When given, symbol construction is skipped.

    Returns:
        An ``mx.mod.Module`` for ``net`` with one ``mx.gpu`` context per GPU id.

    Raises:
        NotImplementedError: ``net`` is None and ``model_specs`` does not name
            a supported network.
        ValueError: ``args.gpus`` contains a non-integer id or lists no ids.
    """
    if net is None:
        # The following lines show how to create symbols for our networks.
        if model_specs['net_type'] == 'rna':
            # Mutates the ``cfg`` object of util.symbol.symbol in place;
            # presumably the builder below reads these settings.
            from util.symbol.symbol import cfg as symcfg
            symcfg['lr_type'] = 'alex'
            symcfg['workspace'] = dargs.mx_workspace
            symcfg['bn_use_global_stats'] = True
            if model_specs['net_name'] == 'a1':
                from util.symbol.resnet_v2 import fcrna_model_a1
                net = fcrna_model_a1(margs.classes, margs.feat_stride,
                                     bootstrapping=True)
        if net is None:
            raise NotImplementedError(
                'Unknown network: {}'.format(vars(margs)))

    # Skip blank entries so a stray/trailing comma (e.g. "0,1,") is tolerated;
    # int() itself already ignores surrounding whitespace.
    gpu_ids = [int(token) for token in args.gpus.split(',') if token.strip()]
    if not gpu_ids:
        raise ValueError(
            'No GPU ids given in args.gpus: {!r}'.format(args.gpus))
    contexts = [mx.gpu(gpu_id) for gpu_id in gpu_ids]
    return mx.mod.Module(net, context=contexts)
# [scraped web-page residue] "Comment list"
# [scraped web-page residue] "Table of contents"