def initialize(args, data_shape, class_labels, class_num):
    """Build a new CNN model or load one from disk, and attach class/training settings.

    Args:
        args: parsed options object (presumably an ``argparse.Namespace`` -- confirm
            against caller). Reads ``model``, ``batch_size``, ``border_mode``,
            ``model_desc``, ``activation``, ``weight_init`` and ``skip_layer_updates``.
        data_shape: input data shape. Passed to ``ModelCNN.build`` when building a
            new model; compared with ``model.data_shape`` when loading one.
        class_labels: stored on the model as ``model.class_labels``.
        class_num: stored on the model as ``model.class_num``.

    Returns:
        The newly built (``args.model is None``) or loaded model, with
        ``class_labels``, ``class_num`` and ``skip_layer_updates`` set.

    Raises:
        AssertionError: if a loaded model's ``data_shape`` does not match
            ``data_shape`` (only while assertions are enabled, i.e. not under -O).
    """
    cudnn_info = (
        theano.config.dnn.conv.algo_fwd,
        theano.config.dnn.conv.algo_bwd_data,
        theano.config.dnn.conv.algo_bwd_filter,
    )
    # logging formats the message lazily with %-style args: every extra positional
    # argument needs a matching placeholder, otherwise the record fails to format.
    logging.info(
        "Using theano version: %s (cudnn fwd=%s,bwd data=%s,bwd filter=%s)",
        theano.__version__, *cudnn_info
    )

    if args.model is None:
        # Construct a fresh convolutional model.
        logging.info("Building convolutional model (%i classes)...", class_num)
        model = ModelCNN()
        model.batch_size = args.batch_size
        # Class info is set before build(), which may depend on it.
        model.class_labels = class_labels
        model.class_num = class_num
        # Allow padding to be given as an integer in border mode: "2" -> (2, 2).
        # Any value that does not convert to int is passed through unchanged.
        try:
            n = int(args.border_mode)
        except (TypeError, ValueError):
            border_mode = args.border_mode
        else:
            border_mode = (n, n)
        model.build(args.model_desc, data_shape, args.activation, border_mode,
                    list(args.weight_init))
    else:
        model = load_from_file(args.model, args.batch_size)
        model.class_labels = class_labels
        model.class_num = class_num
        assert data_shape == model.data_shape, "Mismatching data shapes in .mdl and data: " + str(data_shape) + "!=" + str(model.data_shape)

    model.skip_layer_updates = args.skip_layer_updates
    if model.skip_layer_updates:
        logging.info("Skipping layer updates: %s", model.skip_layer_updates)
    return model
#
评论列表
文章目录