model.py source code

python

Project: enet-keras   Author: PavlosMelissinos
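from keras.layers import Activation, Input, Reshape
from keras.models import Model
from keras.utils import plot_model

# encoder, decoder and valid_shapes are defined elsewhere in the enet-keras
# project; their import paths are not shown in this file excerpt.


def build(nc, w, h,
          loss='categorical_crossentropy',
          optimizer='adadelta',
          plot=False,
          **kwargs):
    """Build and compile the ICNet segmentation model.

    nc: number of classes; w, h: input width and height in pixels
    (either may be None for a variable input size).
    Returns (model, name).
    """
    # Validate the size before building any layers, and skip the check for
    # unknown (None) dimensions, which would otherwise raise a TypeError.
    if any(d is not None and d < 161 for d in (w, h)):
        raise ValueError('Input image tensor must be at least 161px in both width and height')

    # Pixels per image; -1 lets Reshape infer the size when w or h is None.
    data_shape = w * h if None not in (w, h) else -1
    inp = Input(shape=(h, w, 3))
    shapes = valid_shapes(inp)

    out = encoder.build(inp, valid_shapes=shapes)
    out = decoder.build(inp=inp, encoder=out, nc=nc, valid_shapes=shapes)

    # Flatten to (pixels, nc) and apply a per-pixel softmax over the classes.
    out = Reshape((data_shape, nc))(out)  # TODO: need to remove data_shape for multi-scale training
    out = Activation('softmax')(out)
    model = Model(inputs=inp, outputs=out)

    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy', 'mean_squared_error'])
    name = 'icnet'

    if plot:
        plot_model(model, to_file='{}.png'.format(name), show_shapes=True)

    return model, name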
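The final `Reshape` and `Activation('softmax')` layers turn the decoder's (h, w, nc) score map into h*w rows of nc values and normalize each row, so every pixel gets a probability distribution over the classes. The following is a minimal pure-Python sketch of that shape logic for a single image, assuming the scores arrive as nested lists of shape (h, w, nc) with no batch dimension; the helper `flatten_and_softmax` is hypothetical and does not use Keras.

```python
import math
from typing import List


def flatten_and_softmax(scores: List[List[List[float]]]) -> List[List[float]]:
    """Mimic Reshape((h*w, nc)) followed by a softmax over the class axis.

    `scores` is a hypothetical (h, w, nc) score map for one image.
    Returns h*w rows (one per pixel, row-major order), each summing to 1.
    """
    # (h, w, nc) -> (h*w, nc): walk rows top to bottom, pixels left to right
    rows = [pixel for row in scores for pixel in row]
    probs = []
    for logits in rows:
        m = max(logits)  # subtract the max before exp() for numerical stability
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


# A 2x2 "image" with 3 classes
scores = [
    [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]],
    [[5.0, 1.0, 1.0], [2.0, 2.0, -1.0]],
]
probs = flatten_and_softmax(scores)
print(len(probs), len(probs[0]))  # 4 3
```

Flattening in row-major order means pixel (i, j) ends up at row i*w + j, which is the order the training targets must follow as well.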
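Because the model is compiled with `categorical_crossentropy` by default and its output has shape (h*w, nc) per image, each training target has to be laid out the same way: one one-hot row of length nc per pixel. The following is a minimal sketch, assuming the ground truth is an (h, w) mask of integer class ids given as nested lists; the helper `mask_to_targets` is hypothetical and not part of enet-keras.

```python
from typing import List


def mask_to_targets(mask: List[List[int]], nc: int) -> List[List[int]]:
    """Turn an (h, w) mask of class ids into (h*w, nc) one-hot rows.

    Rows are emitted in row-major order, matching how Reshape flattens
    the model output.
    """
    targets = []
    for row in mask:
        for label in row:
            if not 0 <= label < nc:
                raise ValueError(f"label {label} outside [0, {nc})")
            one_hot = [0] * nc
            one_hot[label] = 1
            targets.append(one_hot)
    return targets


mask = [[0, 2], [1, 1]]
print(mask_to_targets(mask, nc=3))
# [[1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 0]]
```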