transformer_net.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:neural-style 作者: jayanthkoushik 项目源码 文件源码
def get_transformer_net(X, weights=None):
    input_ = Input(tensor=X, shape=(3, 256, 256))
    y = conv_layer(input_, 32, 9)
    y = conv_layer(y, 64, 3, subsample=2)
    y = conv_layer(y, 128, 3, subsample=2)
    y = residual_block(y)
    y = residual_block(y)
    y = residual_block(y)
    y = residual_block(y)
    y = residual_block(y)
    y = conv_layer(y, 64, 3, upsample=2)
    y = conv_layer(y, 32, 3, upsample=2)
    y = conv_layer(y, 3, 9, only_conv=True)
    y = Activation("tanh")(y)
    y = Lambda(lambda x: x * 150, output_shape=(3, None, None))(y)

    net = Model(input=input_, output=y)
    if weights is not None:
        try:
            net.load_weights(weights)
        except OSError as e:
            print(e)
            sys.exit(1)
    return net
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号