Title

Guided Image Filtering

Summary

作者基于bilateral filter提出guided filter,更好地在保护边缘的同时平滑图像,可用于降噪,上色,消色,去雾,HDR压缩等多方面

Research Objective

提出一个更好的保护边缘的平滑方法

Background and Problems

降噪或提取图像有效结构有广泛应用。一些explicit linear translation-invariant(LTI) filters(Gaussian filter, Laplacian filter, and Sobel filter)被广泛应用在模糊/锐化,边缘检测,特征提取,high dynamic range(HDR)压缩,图像拼接和图像消光。

潜在的问题:

  • 滤波时引入给定的引导图像的信息,比如亮度通道有一致的边缘,保护细的结构如头发信息。
    • 方法一:优化二次函数,基于引导图像对未知输出加以约束,通过求解引导图的稀疏矩阵获得解。这种方法被广泛用在着色,消光,除雾,多尺度分解。效果不错,计算时间长。
    • 方法二:基于引导图构建滤波器。bilateral filter应该是其中最流行的。它的输出是周围像素的加权平均,而权重来自引导图像中强度、颜色相似度。这种卷积可以平滑小波动,保留边缘,但可能造成伪影,无法快速实现。当时有很多加速这个方法的研究。
    • 方法一和方法二高度相关,方法一是隐式滤波,方法二是显式滤波,通常显式滤波会比基于优化的方法更简单更快。

Method(s)

  • 新的explicit image filter,称之为guided filter。卷积的输出是引导图的线性变化参数。这种卷积和bilateral filter一样能保护边缘地平滑,但不会受到梯度反转伪影的影响。同时它也和matting laplacian matrix相关。

image.png

  1. 引导滤波的重要假设是输出图和引导图在滤波窗口上存在局部线性关系。
  2. 输入图和输出图的差值应视为噪声
  3. 优化目标是使噪声最小,用最小二乘最优化方法(岭回归)求解线性关系的参数
  4. 对多个卷积窗口的输出图结果求均值(推导时是这么推的,实际应用的时候是整图操作)

image.png

参考:

Evaluation

作者做了些实验,列举了大量图片的细节,证明guided filter在细节增强、HDR压缩、降噪、消光/引导羽毛、去雾、上采样、上色等方面做的非常好。

Conclusion

  • guided filter能在平滑时保护边缘,计算有效快速(O(N))
  • guided filter有超越降噪的多方面应用,
  • guided filter有所有explicit filter的共有限制:会在边缘有光晕,因为滤波器无法区分哪些要平滑哪些要保护,但是瑕不掩瑜,这仍然是个好方法

Reference

  • Bilateral filtering for gray and color images. ICCV (1998)
  • Non-linear gaussian filters performing edge preserving diffusion.(1995)
  • Digital photography with flash and no-flash image pairs(2004) joint bilateral filter

Code

提供两个pytorch版本的,原始代码来自:https://github.com/wuhuikai/DeepGuidedFilter

torch 基础版本

  1. import torch
  2. from torch import nn
  3. import cv2
  4. import numpy as np
  5. from torch.autograd import Variable
  6. def cv2tensor(x, is01=False):
  7. if is01:
  8. x = x / 255. # x: [0, 1]
  9. else:
  10. x = x / 255. * 2 - 1 # x: [-1, 1]
  11. x = torch.from_numpy(x.transpose((2, 0, 1))).float().unsqueeze(0)
  12. return x
  13. def save_tensor_img(tensor, save_path, is01=False):
  14. img_numpy = tensor.detach().cpu().numpy().transpose((1,2,0))
  15. if is01:
  16. img_numpy = img_numpy * 255.0
  17. else:
  18. img_numpy = (img_numpy * 0.5 + 0.5) * 255.0
  19. img_numpy = np.clip(img_numpy, 0., 255.)
  20. img_output = img_numpy.astype(np.uint8)
  21. img_output = cv2.cvtColor(img_output, cv2.COLOR_RGB2BGR)
  22. #img_output = cv2.resize(img_output, (512, 512))
  23. cv2.imwrite(save_path, img_output)
  24. def diff_x(input, r):
  25. assert input.dim() == 4
  26. left = input[:, :, r:2 * r + 1]
  27. middle = input[:, :, 2 * r + 1: ] - input[:, :, :-2 * r - 1]
  28. right = input[:, :, -1: ] - input[:, :, -2 * r - 1: -r - 1]
  29. output = torch.cat([left, middle, right], dim=2)
  30. return output
  31. def diff_y(input, r):
  32. assert input.dim() == 4
  33. left = input[:, :, :, r:2 * r + 1]
  34. middle = input[:, :, :, 2 * r + 1: ] - input[:, :, :, :-2 * r - 1]
  35. right = input[:, :, :, -1: ] - input[:, :, :, -2 * r - 1: -r - 1]
  36. output = torch.cat([left, middle, right], dim=3)
  37. return output
  38. class BoxFilter(nn.Module):
  39. def __init__(self, r):
  40. super(BoxFilter, self).__init__()
  41. self.r = r
  42. def forward(self, x):
  43. assert x.dim() == 4
  44. # numpy.cumsum给定轴上的和
  45. return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
  46. class GuidedFilter(nn.Module):
  47. def __init__(self, r, eps=1e-8):
  48. super(GuidedFilter, self).__init__()
  49. self.r = r
  50. self.eps = eps
  51. self.boxfilter = BoxFilter(r)
  52. def forward(self, x, y):
  53. n_x, c_x, h_x, w_x = x.size()
  54. n_y, c_y, h_y, w_y = y.size()
  55. assert n_x == n_y
  56. assert c_x == 1 or c_x == c_y
  57. assert h_x == h_y and w_x == w_y
  58. assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1
  59. # N
  60. N = self.boxfilter(Variable(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
  61. # mean_x
  62. mean_x = self.boxfilter(x) / N
  63. # mean_y
  64. mean_y = self.boxfilter(y) / N
  65. # cov_xy
  66. cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
  67. # var_x
  68. var_x = self.boxfilter(x * x) / N - mean_x * mean_x
  69. # 这里是求A和b的重点
  70. # A
  71. A = cov_xy / (var_x + self.eps)
  72. # b
  73. b = mean_y - A * mean_x
  74. # mean_A; mean_b
  75. mean_A = self.boxfilter(A) / N
  76. mean_b = self.boxfilter(b) / N
  77. return mean_A * x + mean_b

torch 优化版本

把好多除法给删除了,把N当作权重,只在初始化的时候算一次(这样的话需要指定好宽高)

class GuidedFilter2(nn.Module):
    def __init__(self, r, eps=1e-8, height=512, width=512):
        super(GuidedFilter2, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)
        self.N = 1/self.boxfilter(torch.ones((1, 3, height, width)))

    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # mean_x
        self.N = self.N.to(x.device)
        mean_x = self.boxfilter(x) * self.N
        # mean_y
        mean_y = self.boxfilter(y) * self.N
        # cov_xy
        cov_xy = self.boxfilter(x * y) * self.N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) * self.N - mean_x * mean_x

        # 这里是求A和b的重点
        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x

        # mean_A; mean_b
        mean_A = self.boxfilter(A) * self.N
        mean_b = self.boxfilter(b) * self.N
        out = mean_A * x + mean_b
        return out

输入同一张图的torch优化版本

我将guided filter当作保边滤波使用,引导图也是输入图,因此进一步做了相应的改进

class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8, height=512, width=512):
        super(GuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)
        self.N = 1/self.boxfilter(torch.ones((1, 3, height, width)))

    def forward(self, x):
        n_x, c_x, h_x, w_x = x.size()
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1


        # mean_x
        self.N = self.N.to(x.device)
        mean_x = self.boxfilter(x) * self.N

        # cov_xy,也即var_x
        cov_xy = self.boxfilter(x * x) * self.N - mean_x * mean_x

        # 这里是求A和b的重点
        # A
        A = cov_xy / (cov_xy + self.eps)
        # b
        b = (1.0 - A) * mean_x

        # mean_A; mean_b
        mean_A = self.boxfilter(A) * self.N
        mean_b = self.boxfilter(b) * self.N
        out = mean_A * x + mean_b
        return out