util.py source code

python

Project: chainer-neural-style Author: dsanno
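from chainer import cuda, Variable


def nearest_neighbor_patch(x, patch, patch_norm):
    # For every spatial position of x, select the patch vector with the
    # highest cosine similarity and assemble the result in x's layout.
    assert patch.data.shape[0] == 1, 'mini batch size of patch must be 1'
    assert patch_norm.data.shape[0] == 1, 'mini batch size of patch_norm must be 1'

    xp = cuda.get_array_module(x.data)
    z = x.data
    b, ch, h, w = z.shape
    # Flatten to (ch, b * h * w): one column per spatial position.
    z = z.transpose((1, 0, 2, 3)).reshape((ch, -1))
    # L2-normalize each column (an all-zero column would divide by zero).
    norm = xp.expand_dims(xp.sum(z ** 2, axis=0) ** 0.5, 0)
    z = z / xp.broadcast_to(norm, z.shape)
    # Flatten the patches to (ch, n_patches); patch_norm is used as the
    # per-column normalizer, so it is expected to hold each column's norm.
    p = patch.data
    p_norm = patch_norm.data
    p = p.reshape((ch, -1))
    p_norm = p_norm.reshape((1, -1))
    p_normalized = p / xp.broadcast_to(p_norm, p.shape)
    # correlation[i, j]: similarity between position i of x and patch j.
    correlation = z.T.dot(p_normalized)
    # argmax picks the most similar patch (largest correlation).
    best_index = xp.argmax(correlation, axis=1)
    # Gather the unnormalized patch vectors and restore (b, ch, h, w).
    nearest_neighbor = p.take(best_index, axis=1).reshape((ch, b, h, w)).transpose((1, 0, 2, 3))
    return Variable(nearest_neighbor)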
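
The matching step in the function above can be tried without Chainer. The following is a minimal NumPy-only sketch, not the project's code: the helper name `nearest_columns` and the toy arrays are hypothetical, and the norms are computed inside the helper rather than passed in as `patch_norm`. It normalizes the columns, takes the dot product as cosine similarity, and gathers the unnormalized best match with `argmax` and `take`.

```python
import numpy as np


def nearest_columns(z, p):
    """Return, for each column of z, the column of p with the highest
    cosine similarity, together with the selected indices.

    z: (ch, n) query vectors, p: (ch, m) candidate vectors.
    Both names are illustrative and not part of the original project.
    """
    # L2-normalize every column so the dot product is a cosine similarity.
    z_unit = z / np.linalg.norm(z, axis=0, keepdims=True)
    p_unit = p / np.linalg.norm(p, axis=0, keepdims=True)
    similarity = z_unit.T.dot(p_unit)      # shape (n, m)
    index = np.argmax(similarity, axis=1)  # best candidate per query
    # Gather the original (unnormalized) candidate columns.
    return p.take(index, axis=1), index


# Three 2-D candidates pointing along +x, +y and -x.
candidates = np.array([[1.0, 0.0, -1.0],
                       [0.0, 2.0, 0.0]])
# Four queries, each roughly aligned with one of the candidates.
queries = np.array([[0.1, 5.0, -3.0, 0.0],
                    [3.0, 0.2, 0.1, 1.0]])

nearest, index = nearest_columns(queries, candidates)
print(index)
print(nearest)
```

Because the gather uses the unnormalized array, the returned columns keep their original magnitude, which mirrors `p.take(...)` in the function above.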