model_cnn.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:denet 作者: lachlants 项目源码 文件源码
def initialize(args, data_shape, class_labels, class_num):
    """Build a new ModelCNN or load a saved one, then attach run settings to it.

    Args:
        args: parsed options object. Attributes read here: ``model`` (path of a
            saved model, or None to build a fresh one), ``batch_size``,
            ``border_mode``, ``model_desc``, ``activation``, ``weight_init``
            and ``skip_layer_updates``.
        data_shape: input data shape. Passed to ``model.build`` for a new
            model; compared against ``model.data_shape`` for a loaded one.
        class_labels: stored on the model as ``model.class_labels``.
        class_num: stored on the model as ``model.class_num``.

    Returns:
        The constructed or loaded model.

    Raises:
        AssertionError: if a loaded model's ``data_shape`` does not equal
            ``data_shape``.
    """
    cudnn_info = (
        theano.config.dnn.conv.algo_fwd,
        theano.config.dnn.conv.algo_bwd_data,
        theano.config.dnn.conv.algo_bwd_filter,
    )
    # logging formats lazily with `msg % args`: every extra positional argument
    # needs a matching placeholder, otherwise the record fails to format and the
    # message is replaced by a "--- Logging error ---" traceback.
    logging.info("Using theano version: %s (cudnn fwd=%s,bwd data=%s,bwd filter=%s)",
                 theano.__version__, *cudnn_info)

    if args.model is None:
        # Construct a fresh convolutional model.
        logging.info("Building convolutional model (%i classes)...", class_num)
        model = ModelCNN()
        model.batch_size = args.batch_size
        # Class metadata is assigned before build(); keep this order in case
        # build() reads it (not verifiable from this file).
        model.class_labels = class_labels
        model.class_num = class_num

        # border_mode may be an integer n (meaning symmetric (n, n) padding) or
        # any other value, which is passed through to build() unchanged.
        try:
            n = int(args.border_mode)
        except ValueError:
            border_mode = args.border_mode
        else:
            border_mode = (n, n)

        model.build(args.model_desc, data_shape, args.activation, border_mode, list(args.weight_init))
    else:
        model = load_from_file(args.model, args.batch_size)
        model.class_labels = class_labels
        model.class_num = class_num
        # Explicit check instead of `assert` so it still runs under `python -O`.
        # AssertionError is kept so callers see the same exception type as before.
        if not (data_shape == model.data_shape):
            raise AssertionError("Mismatching data shapes in .mdl and data: " + str(data_shape) + "!=" + str(model.data_shape))

    model.skip_layer_updates = args.skip_layer_updates
    if model.skip_layer_updates:
        logging.info("Skipping layer updates: %s", model.skip_layer_updates)

    return model

#
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号