关键词:Haar, Wavlet

Haar小波的基本原理

梳理自

如数组:[2,4,6,8,10,12,14,16]
求均值(整体信息):[3,7,11,15]; 求差值(细节信息):[-1,-1,-1,-1] —> 因此,Harr小波分解后,尺寸会变成原来的一半
image.png

二维

对于二维haar小波,(1)首先,沿着矩阵的每一行做一维的Haar变换;(2)然后,沿着矩阵的每一列做一维的哈尔变换;(3)对于每个低频分量矩阵(近似信息)重复步骤(1)和(2)直到完成指定的等级划分。过程如下:
image.png

图中的低频和高频就对应着一维中的整体和细节
行分解和列分解的顺序是可以互换的,保持一致即可。

Haar小波实现

我是基于Deep Learning做图像生成的,推荐一个库pytorch_wavlet。
https://pytorch-wavelets.readthedocs.io/en/latest/readme.html
https://github.com/fbcotter/pytorch_wavelets

自带的example如下:

  1. import torch
  2. from pytorch_wavelets import DWTForward, DWTInverse # (or import DWT, IDWT)
  3. xfm = DWTForward(J=3, mode='zero', wave='db3') # Accepts all wave types available to PyWavelets
  4. ifm = DWTInverse(mode='zero', wave='db3')
  5. X = torch.randn(10,5,64,64)
  6. Yl, Yh = xfm(X)
  7. print(Yl.shape)
  8. >>> torch.Size([10, 5, 12, 12])
  9. print(Yh[0].shape)
  10. >>> torch.Size([10, 5, 3, 34, 34])
  11. print(Yh[1].shape)
  12. >>> torch.Size([10, 5, 3, 19, 19])
  13. print(Yh[2].shape)
  14. >>> torch.Size([10, 5, 3, 12, 12])
  15. Y = ifm((Yl, Yh))
  16. import numpy as np
  17. np.testing.assert_array_almost_equal(Y.cpu().numpy(), X.cpu().numpy())

我使用的过程中,(1)会先做小波分解,把一个尺寸的图像拼接起来,变成pytorch的[batch_size, channel, H, W]形式;(2)处理完毕后,将各个尺寸的分频结果再做小波逆变换合成原尺寸图像

因此,我会将每张个尺寸的图像保存成字典,使用方法如下:

  1. # 分解过程
  2. def get_wavlet(img):
  3. '''
  4. 我期望得到一个字典,有4个key,分别为stride=8的12通道,stride=4的9通道,stride=2的9通道, 原图
  5. s8: torch.Size([4, 12, 64, 64])
  6. s4: torch.Size([4, 9, 128, 128])
  7. s2: torch.Size([4, 9, 256, 256])
  8. s1: torch.Size([4, 3, 512, 512])
  9. '''
  10. from pytorch_wavelets import DWTForward
  11. xfm = DWTForward(J=3, mode='zero', wave='haar').cuda()
  12. Yl, Yh = xfm(img)
  13. Yh_retensors = {}
  14. for idx, tmp in enumerate(Yh):
  15. stride = 2 ** (idx + 1)
  16. Yh_retensors[f's{stride}'] = tmp.reshape(tmp.shape[0], -1, tmp.shape[3], tmp.shape[4])
  17. tensors = {}
  18. tensors['s8'] = torch.cat([Yl, Yh_retensors['s8']], dim=1)
  19. tensors['s4'] = Yh_retensors['s4']
  20. tensors['s2'] = Yh_retensors['s2']
  21. tensors['s1'] = img
  22. return tensors
  23. # 还原过程:
  24. def IDWT(tensors):
  25. '''
  26. tensors是个字典,包含stride=8的12通道,stride=4的9通道,stride=2的9通道
  27. s8: torch.Size([4, 12, 64, 64])
  28. s4: torch.Size([4, 9, 128, 128])
  29. s2: torch.Size([4, 9, 256, 256])
  30. 需要将它还原是DWTInverse函数需要的形式,并最终生成原图
  31. '''
  32. Yl = tensors['s8'][:, :3]
  33. Yh = [tensors['s2'], tensors['s4'], tensors['s8'][:, 3:]]
  34. for idx in range(len(Yh)):
  35. Yh[idx] = Yh[idx].reshape(Yh[idx].shape[0], 3, -1, Yh[idx].shape[2], Yh[idx].shape[3])
  36. ifm = DWTInverse(mode='zero', wave='haar').cuda()
  37. output = ifm((Yl, Yh))
  38. return output