def lin_interpolation_2d(inp, dim):
    """Collapse the spatial grid of ``inp`` into one weighted sum per channel.

    Builds a frozen ``SeparableConv2D`` whose kernel spans the entire spatial
    extent of ``inp``. Every channel's depthwise kernel is set to the weight
    map returned by ``linspace_2d(num_rows, num_cols, dim=dim)``. The pointwise
    kernel is the identity, so output channel ``i`` is the sum over all spatial
    positions of ``inp[..., i]`` times that weight map. Channels are never
    mixed.

    Args:
        inp: 4-D Keras tensor whose static shape is read as
            ``(batch, rows, cols, channels)``; rows, cols and channels must be
            known at graph-construction time.
            NOTE(review): assumes a channels-last layout (the layer's default
            ``data_format`` follows the backend config) — confirm.
        dim: Forwarded unchanged to ``linspace_2d``. It presumably selects the
            axis along which the weights ramp — verify against that helper.

    Returns:
        Keras tensor of shape ``(batch, channels, 1)``.

    Raises:
        ValueError: if ``inp`` is not rank 4, or its rows/cols/channels are not
            statically known (the kernel size and weights depend on them).
    """
    import numpy as np  # local import: the module's own imports are not visible here

    shape = K.int_shape(inp)
    if len(shape) != 4 or any(d is None for d in shape[1:]):
        raise ValueError(
            "lin_interpolation_2d expects a rank-4 input with static "
            f"rows/cols/channels, got shape {shape}"
        )
    num_rows, num_cols, num_filters = shape[1:]

    # A kernel covering the full spatial extent, with the default 'valid'
    # padding, yields a 1x1 output map. trainable=False freezes the layer from
    # creation, so an optimizer never updates the hand-set weights below.
    conv = SeparableConv2D(
        num_filters, (num_rows, num_cols), use_bias=False, trainable=False
    )
    x = conv(inp)  # calling the layer builds it, so its weights now exist

    # use_bias=False -> exactly [depthwise_kernel, pointwise_kernel].
    depthwise, pointwise = conv.get_weights()
    weights = np.asarray(linspace_2d(num_rows, num_cols, dim=dim))
    # Same (rows, cols) weight map for every channel (depth_multiplier == 1);
    # this overwrites every depthwise entry, so no zero-fill is needed.
    depthwise[...] = weights[:, :, np.newaxis, np.newaxis]
    # Identity 1x1 pointwise kernel: channel i passes straight to output i.
    pointwise[...] = np.eye(num_filters, dtype=pointwise.dtype)
    conv.set_weights([depthwise, pointwise])

    # (batch, 1, 1, C) -> (batch, 1, C) -> (batch, C) -> (batch, C, 1).
    # Kept as three separate Lambda layers to preserve the model's layer list.
    x = Lambda(lambda t: K.squeeze(t, axis=1))(x)
    x = Lambda(lambda t: K.squeeze(t, axis=1))(x)
    x = Lambda(lambda t: K.expand_dims(t, axis=-1))(x)
    return x
评论列表
文章目录