models.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:diagnose-heart 作者: woshialex 项目源码 文件源码
# Layer recipes for build_fcn_segmenter, keyed by architecture version.
# Each step is (kind, num_filters, filter_size); 'pool' and 'ups' steps take
# no filter arguments and carry None in those slots.
#   'conv' -> batch-normalised Conv2DLayer with the default padding
#   'dec'  -> batch-normalised Conv2DLayer with pad='full'
#   'pool' -> MaxPool2DLayer(pool_size=2)
#   'ups'  -> Upscale2DLayer(scale_factor=2)
_FCN_SEGMENTER_SPECS = {
    1: (  # original author's note: for size 256
        ('conv', 8, 5),
        ('conv', 16, 3),
        ('pool', None, None),
        ('conv', 32, 4),
        ('pool', None, None),
        ('conv', 64, 4),
        ('pool', None, None),
        ('conv', 64, 5),
        ('dec', 64, 5),
        ('ups', None, None),
        ('dec', 32, 4),
        ('ups', None, None),
        ('dec', 16, 4),
        ('ups', None, None),
        ('dec', 8, 3),
    ),
    2: (  # original author's note: for size 196
        ('conv', 8, 5),
        ('conv', 16, 3),
        ('pool', None, None),
        ('conv', 32, 4),
        ('pool', None, None),
        ('conv', 64, 5),
        ('pool', None, None),
        ('conv', 128, 6),
        ('dec', 64, 6),
        ('ups', None, None),
        ('dec', 32, 5),
        ('ups', None, None),
        ('dec', 16, 4),
        ('ups', None, None),
        ('dec', 8, 3),
    ),
}


def build_fcn_segmenter(input_var, shape, version=1):
    """Build a fully convolutional encoder/decoder segmentation network.

    The network is a stack of convolutions and 2x max-pooling steps followed
    by a mirrored stack of ``pad='full'`` convolutions and 2x upscaling steps,
    ending in a single-filter convolution with a sigmoid nonlinearity.

    Args:
        input_var: Passed to ``nn.layers.InputLayer`` as its ``input_var``.
        shape: Passed to ``nn.layers.InputLayer`` as its ``shape``.
        version: Architecture variant, 1 or 2. The original author's notes
            say version 1 is for size 256 and version 2 for size 196.

    Returns:
        A 3-tuple ``(layers, output, output_det)``:
        ``layers`` is a dict mapping layer names to layers, with keys
        ``'input'``, ``'output'`` and ``'<kind><index>'`` for everything in
        between (e.g. ``'conv1'``, ``'pool3'``, ``'dec9'``, ``'ups10'``);
        ``output`` is ``nn.layers.get_output`` of the output layer;
        ``output_det`` is the same with ``deterministic=True``.

    Raises:
        ValueError: If ``version`` is not a supported architecture version.
    """
    # Validate before building anything: previously an unknown version fell
    # through both branches and died later with an opaque KeyError: 'output'.
    if version not in _FCN_SEGMENTER_SPECS:
        raise ValueError(
            'unsupported version {!r}; expected one of {}'.format(
                version, sorted(_FCN_SEGMENTER_SPECS)))

    layer = nn.layers.InputLayer(shape, input_var)
    ret = {'input': layer}

    for kind, num_filters, filter_size in _FCN_SEGMENTER_SPECS[version]:
        # `bn` is a helper defined outside this block (presumably batch
        # normalisation -- confirm against the file's imports).
        if kind == 'conv':
            layer = bn(nn.layers.Conv2DLayer(
                layer, num_filters=num_filters, filter_size=filter_size))
        elif kind == 'dec':
            layer = bn(nn.layers.Conv2DLayer(
                layer, num_filters=num_filters, filter_size=filter_size,
                pad='full'))
        elif kind == 'pool':
            layer = nn.layers.MaxPool2DLayer(layer, pool_size=2)
        else:  # 'ups'
            layer = nn.layers.Upscale2DLayer(layer, scale_factor=2)
        # The numeric suffix is the number of layers registered so far, which
        # keeps the key names identical to the historical ones.
        ret['{}{}'.format(kind, len(ret))] = layer

    ret['output'] = nn.layers.Conv2DLayer(
        layer, num_filters=1, filter_size=5, pad='full',
        nonlinearity=nn.nonlinearities.sigmoid)

    output = nn.layers.get_output(ret['output'])
    output_det = nn.layers.get_output(ret['output'], deterministic=True)
    return ret, output, output_det
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号