def get_transformer_net(X, weights=None):
    """Build the image transformation network and optionally load its weights.

    The graph is built as:

    * a 9x9 conv with 32 filters;
    * two 3x3 convs with 64 and 128 filters, called with ``subsample=2``;
    * five residual blocks;
    * two 3x3 convs with 64 and 32 filters, called with ``upsample=2``;
    * a final 9x9 conv to 3 channels (``only_conv=True``);
    * ``tanh``, with the result scaled by 150.

    Every output value therefore lies in [-150, 150].

    Args:
        X: Tensor wrapped as the model input via ``Input(tensor=X)``. The
            declared input shape is ``(3, 256, 256)``, i.e. channels-first;
            presumably X must be compatible with it -- confirm with callers.
        weights: Optional path handed to ``Model.load_weights``. If loading
            fails with an I/O error, the error is reported on stderr and the
            process exits with status 1.

    Returns:
        The Keras ``Model`` mapping ``input_`` to the scaled output.
    """
    input_ = Input(tensor=X, shape=(3, 256, 256))

    # Encoder. NOTE(review): subsample=2 presumably means stride 2
    # (down-sampling); this depends on conv_layer, defined elsewhere.
    y = conv_layer(input_, 32, 9)
    y = conv_layer(y, 64, 3, subsample=2)
    y = conv_layer(y, 128, 3, subsample=2)

    # Five residual blocks, applied in sequence.
    for _ in range(5):
        y = residual_block(y)

    # Decoder. NOTE(review): upsample=2 presumably undoes the two stride-2
    # convolutions above; this depends on conv_layer.
    y = conv_layer(y, 64, 3, upsample=2)
    y = conv_layer(y, 32, 3, upsample=2)
    y = conv_layer(y, 3, 9, only_conv=True)

    # tanh bounds values to [-1, 1]; multiplying by 150 widens that to
    # [-150, 150]. NOTE(review): 150 presumably approximates the range of
    # mean-subtracted pixel values -- confirm.
    y = Activation("tanh")(y)
    y = Lambda(lambda x: x * 150, output_shape=(3, None, None))(y)

    net = Model(input=input_, output=y)

    if weights is not None:
        try:
            net.load_weights(weights)
        except (IOError, OSError) as e:
            # On Python 2, IOError (which h5py raises for unreadable files) is
            # not a subclass of OSError, so both are caught explicitly; on
            # Python 3 they are the same class. sys.exit with a string writes
            # it to stderr and exits with status 1, matching the original
            # exit code.
            sys.exit("Could not load weights from {!r}: {}".format(weights, e))
    return net
评论列表
文章目录