train.py source code

python

Project: dilation    Author: fyu
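import caffe    # pycaffe; provides NetSpec
import network  # network.py from the dilation repository (layer-building helpers)


def make_joint(options, is_training):
    """Build the front end and context module as one net; return a NetParameter."""
    # Pick the batch size and image/label list files for the requested split.
    batch_size = options.train_batch if is_training else options.test_batch
    image_path = options.train_image if is_training else options.test_image
    label_path = options.train_label if is_training else options.test_label
    net = caffe.NetSpec()
    # Data layers for images and labels; crop_size and mean go to the helper.
    net.data, net.label = network.make_image_label_data(
        image_path, label_path, batch_size,
        is_training, options.crop_size, options.mean)
    # VGG-based front end; the helper returns a tuple whose first item is its output top.
    last = network.build_frontend_vgg(
        net, net.data, options.classes)[0]
    # Context module stacked on the front-end output; options.layers is passed through.
    last = network.build_context(
        net, last, options.classes, options.layers)[0]
    # Optionally upsample the class scores before computing the loss.
    if options.up:
        net.upsample = network.make_upsample(last, options.classes)
        last = net.upsample
    net.loss = network.make_softmax_loss(last, net.label)
    # Only the test net also reports accuracy.
    if not is_training:
        net.accuracy = network.make_accuracy(last, net.label)
    return net.to_proto()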
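make_joint reads all of its settings from a single options object and branches on is_training in three places. The following is a minimal sketch, not code from the dilation repository: make_options, select_split and tops_set_by_make_joint are hypothetical helpers, options is assumed to be an argparse Namespace, and every value in make_options is a placeholder. It runs without Caffe and shows which attributes the function expects and which tops make_joint itself adds to the NetSpec for the train and test variants (layers created inside network.build_frontend_vgg and network.build_context are not listed).

```python
from argparse import Namespace


def make_options():
    """Hypothetical options object carrying every attribute make_joint reads.

    All values are illustrative placeholders, not defaults from the dilation repo.
    """
    return Namespace(
        train_batch=8, test_batch=1,
        train_image="train_images.txt", test_image="test_images.txt",
        train_label="train_labels.txt", test_label="test_labels.txt",
        crop_size=500,
        mean=[0.0, 0.0, 0.0],
        classes=21,
        layers=8,
        up=False,
    )


def select_split(options, is_training):
    """Mirror make_joint's first three lines: choose batch size and list files."""
    batch_size = options.train_batch if is_training else options.test_batch
    image_path = options.train_image if is_training else options.test_image
    label_path = options.train_label if is_training else options.test_label
    return batch_size, image_path, label_path


def tops_set_by_make_joint(options, is_training):
    """List the NetSpec tops that make_joint assigns directly, in order."""
    tops = ["data", "label"]
    if options.up:
        tops.append("upsample")   # added only when options.up is truthy
    tops.append("loss")
    if not is_training:
        tops.append("accuracy")   # test net only
    return tops


if __name__ == "__main__":
    opts = make_options()
    print(select_split(opts, is_training=True))
    print(tops_set_by_make_joint(opts, is_training=False))
```

With pycaffe and the repository's network.py importable, str(make_joint(options, True)) should give the training net in prototxt text form, since NetSpec.to_proto() returns a NetParameter protobuf message.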