def load_models(args):
    """Select the GPU for this job and load the feature and RNN networks.

    Reads the following attributes from ``args``:

    * ``gpus`` -- optional sequence of GPU device ids; when ``None`` the
      device id is derived directly from ``job_id``.
    * ``job_id`` -- 1-based job index (``job_id - 1`` is used as the device
      id, or as the index into ``gpus``).
    * ``lstm_def`` / ``lstm_param`` -- RNN network definition and weights;
      an empty ``lstm_param`` string builds the network without weights.
    * ``def_file`` / ``param`` -- feature network definition and weights.

    Returns:
        A ``(feature_net, rnn_net)`` tuple of ``caffe.Net`` objects, both
        created in ``caffe.TEST`` phase.

    Raises:
        AssertionError: if ``gpus`` is given and ``job_id`` exceeds its
            length (only when assertions are enabled).
    """
    # Pick the device before any network is constructed.
    caffe.set_mode_gpu()
    if args.gpus is None:
        caffe.set_device(args.job_id - 1)
    else:
        assert args.job_id <= len(args.gpus)
        caffe.set_device(args.gpus[args.job_id - 1])

    # Load the RNN model.
    # Compare by value: ``is not ''`` tests object identity, which only
    # works when the interpreter happens to reuse the same empty-string
    # object for the literal and the argument.
    if args.lstm_param != '':
        rnn_net = caffe.Net(args.lstm_def, args.lstm_param, caffe.TEST)
        print('Loaded RNN network from {:s}.'.format(args.lstm_def))
    else:
        # No weights supplied: build the network from its definition only.
        rnn_net = caffe.Net(args.lstm_def, caffe.TEST)
        print('WARNING: dummy RNN network created.')

    # Load the feature model.
    feature_net = caffe.Net(args.def_file, args.param, caffe.TEST)
    print('Loaded feature network from {:s}.'.format(args.def_file))
    return feature_net, rnn_net
# [web page residue] Comment list
# [web page residue] Article table of contents