1.定义网络基本单元

  # Residual unit: wraps a sub-layer with an identity skip connection.
class Residual(nn.Module):
    """Add the input back onto the output of a wrapped callable.

    ``forward`` computes ``fn(x, **kwargs) + x``, so the output of ``fn``
    must be addable to ``x`` (same shape, or broadcastable to it).

    Args:
        fn: the wrapped sub-layer; extra keyword arguments given to
            ``forward`` are passed straight through to it.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        skip = x
        transformed = self.fn(x, **kwargs)
        return transformed + skip
  # Pre-normalisation wrapper: LayerNorm runs before the wrapped layer.
class PreNorm(nn.Module):
    """Normalise the input with ``nn.LayerNorm(dim)``, then call ``fn`` on it.

    Args:
        dim: size of the last input dimension (the one LayerNorm normalises).
        fn: the wrapped sub-layer; extra keyword arguments given to
            ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
  # Position-wise feed-forward network.
class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension.

    Layout (``self.net`` indices 0..4):
    Linear(dim -> hidden_dim), GELU, Dropout, Linear(hidden_dim -> dim), Dropout.
    The output therefore has the same last-dimension size ``dim`` as the input.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),  # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to dim
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

2.多头注意力机制（跨模态：Q、K 来自 HSI，V 来自 LiDAR）

  # Multi-head attention across two modality streams (Q, K from HSI; V from LiDAR).
class Attention(nn.Module):
    """Multi-head attention that mixes an HSI token stream with a LiDAR one.

    ``forward`` expects a 4-D tensor ``x`` whose dim 1 indexes the modality:
    ``x[:, 0]`` is read as the HSI tokens and ``x[:, 1]`` as the LiDAR tokens
    (any further entries along dim 1 are ignored). The last dimension must
    equal ``dim``. Queries and keys are projected from the HSI tokens and
    values from the LiDAR tokens, so this is cross-modal attention rather
    than self-attention.

    The single attended result (``[batch, tokens, dim]``) is returned twice,
    stacked along dim 1, giving ``[batch, 2, tokens, dim]``.

    Args:
        dim: input/output feature size.
        heads: number of attention heads.
        dim_head: per-head feature size; scores are scaled by ``dim_head ** -0.5``.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        # Three independent projections (not one fused matrix).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        # x must be 4-D; this unpack raises ValueError otherwise.
        _batch, _modalities, _tokens, _width = x.shape
        heads = self.heads

        hsi_tokens = x[:, 0, :, :]
        lidar_tokens = x[:, 1, :, :]

        def split_heads(t):
            # [batch, tokens, heads*dim_head] -> [batch, heads, tokens, dim_head]
            return rearrange(t, 'b n (h d) -> b h n d', h=heads)

        q = split_heads(self.to_q(hsi_tokens))
        k = split_heads(self.to_k(hsi_tokens))
        v = split_heads(self.to_v(lidar_tokens))

        # Scaled dot-product scores q . k^T / sqrt(dim_head): [batch, heads, tokens, tokens]
        scores = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        weights = scores.softmax(dim=-1)

        # Weighted sum of values, then merge heads: [batch, tokens, heads*dim_head]
        context = torch.einsum('bhij,bhjd->bhid', weights, v)
        context = rearrange(context, 'b h n d -> b n (h d)')

        fused = self.to_out(context).unsqueeze(1)
        # The same fused tokens go back to both modality slots: [batch, 2, tokens, dim]
        return torch.cat([fused, fused], dim=1)

3.Transformer

  # Stack of pre-norm residual attention + feed-forward layers.
class Transformer(nn.Module):
    """``depth`` repetitions of an attention sub-layer followed by a feed-forward one.

    Each sub-layer is wrapped as ``Residual(PreNorm(dim, ...))``, i.e.
    ``x = x + sublayer(LayerNorm(x))``. The attention sub-layer is this file's
    ``Attention``, so ``x`` must have the 4-D layout that class expects.

    Args:
        dim: token feature size.
        depth: number of (attention, feed-forward) pairs.
        heads, dim_head: forwarded to ``Attention``.
        mlp_dim: hidden width of each ``FeedForward``.
        dropout: forwarded to both ``Attention`` and ``FeedForward``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()

        def make_layer():
            attention_block = Residual(PreNorm(
                dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)))
            feed_forward_block = Residual(PreNorm(
                dim, FeedForward(dim, mlp_dim, dropout=dropout)))
            return nn.ModuleList([attention_block, feed_forward_block])

        self.layers = nn.ModuleList([make_layer() for _ in range(depth)])

    def forward(self, x):
        for attention_block, feed_forward_block in self.layers:
            x = feed_forward_block(attention_block(x))
        return x

4.ViT 特征融合器

class ViT_spa(nn.Module):
    """Patch-token transformer that fuses an HSI and a LiDAR feature map.

    Both inputs (``[batch, channels, H, W]``) are cut into non-overlapping
    ``patch_size x patch_size`` patches, linearly embedded to ``dim``, given a
    learned per-modality positional embedding, and stacked along a new modality
    axis (``[batch, 2, num_tokens, dim]``) for ``Transformer``. The two output
    token streams are projected back to patch pixels and reassembled into
    ``[batch, channels, H', W]`` maps, where ``H' = (H // patch_size) * patch_size``.

    Both inputs are assumed to share the same spatial size: the HSI height is
    used to un-patchify both outputs, and the HSI token count selects the
    positional embeddings for both.

    Args:
        image_size: side length the positional embeddings are sized for
            (``(image_size // patch_size) ** 2`` tokens); must be divisible by
            ``patch_size``.
        patch_size: side length of each square patch.
        dim: token embedding size.
        depth, heads, mlp_dim, dim_head, dropout: forwarded to ``Transformer``.
        channels: channel count of BOTH inputs (a single ``patch_dim`` is shared).
        emb_dropout: dropout applied to the embedded tokens.

    Raises:
        AssertionError: if ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # Learned positional embeddings, one table per modality.
        self.hsi_pos = nn.Parameter(torch.randn(1, num_patches, dim))
        self.lidar_pos = nn.Parameter(torch.randn(1, num_patches, dim))

        self.hsi_to_embedding = nn.Linear(patch_dim, dim)
        self.lidar_to_embedding = nn.Linear(patch_dim, dim)

        self.hsi_dropout = nn.Dropout(emb_dropout)
        self.lidar_dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.embedding_to_hsi = nn.Linear(dim, patch_dim)
        self.embedding_to_lidar = nn.Linear(dim, patch_dim)

    def forward(self, hsi, lidar):
        p = self.patch_size
        # Number of patch rows; both outputs are un-patchified with the HSI grid.
        grid_h = hsi.shape[2] // p

        # [b, c, H, W] -> [b, num_tokens, p*p*c] -> [b, num_tokens, dim]
        hsi_tokens = self.hsi_to_embedding(
            rearrange(hsi, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p))
        lidar_tokens = self.lidar_to_embedding(
            rearrange(lidar, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p))

        n = hsi_tokens.shape[1]
        hsi_tokens = self.hsi_dropout(hsi_tokens + self.hsi_pos[:, :n])
        lidar_tokens = self.lidar_dropout(lidar_tokens + self.lidar_pos[:, :n])

        # [b, 2, num_tokens, dim]: modality 0 = HSI, modality 1 = LiDAR
        tokens = self.transformer(torch.stack([hsi_tokens, lidar_tokens], dim=1))

        # Plain indexing already removes the modality axis. An extra squeeze(1)
        # here would also drop the token axis when there is only one patch.
        hsi_out = self.embedding_to_hsi(tokens[:, 0])
        lidar_out = self.embedding_to_lidar(tokens[:, 1])

        hsi_out = rearrange(hsi_out, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=grid_h, p1=p, p2=p)
        lidar_out = rearrange(lidar_out, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=grid_h, p1=p, p2=p)
        return hsi_out, lidar_out

5.定义并实现网络模型

![网络结构图](2.svg)

class Net(nn.Module):
    """HSI/LiDAR classifier with two unimodal branches and one ViT-fused branch.

    Four conv stacks are built. Every stage is a 3x3 unpadded Conv2d, then
    BatchNorm2d, then ReLU, with widths 64/128/256:

    * ``hsi_conv1..3`` run on the HSI input alone;
    * ``lidar_conv1..3`` run on the LiDAR input alone;
    * ``hsifuse_conv1..3`` / ``lidarfuse_conv1..3`` run on both inputs. After
      every stage the pair of maps goes through a ``ViT_spa`` fuser, whose
      outputs are added back residually.

    Each branch's stage-3 map is flattened spatially, max-pooled (kernel 5),
    flattened, passed through dropout (p=0.6) and classified by a linear head.
    The returned logits are
    ``sigmoid(w[0]) * hsi_logits + sigmoid(w[1]) * lidar_logits + fused_logits``,
    where ``w`` is a learned 2-vector.

    The hard-coded ``ViT_spa`` image sizes (9, 7, 5) and the head input sizes
    (256 * 5 and 512 * 5) are consistent with 11x11 input patches. Three valid
    3x3 convs give 9x9, 7x7 and 5x5 maps, and max-pooling the 25 stage-3
    positions with kernel 5 leaves 5 values per channel. Other input sizes
    need those numbers changed.

    Args:
        hsi_channels: number of HSI input bands.
        lidar_channels: number of LiDAR input channels.
        num_classes: number of output classes.
    """

    def __init__(self, hsi_channels=30, lidar_channels=1, num_classes=15):
        super().__init__()

        # Unimodal branches.
        self.hsi_conv1 = self._conv_bn_relu(hsi_channels, 64)
        self.hsi_conv2 = self._conv_bn_relu(64, 128)
        self.hsi_conv3 = self._conv_bn_relu(128, 256)

        self.lidar_conv1 = self._conv_bn_relu(lidar_channels, 64)
        self.lidar_conv2 = self._conv_bn_relu(64, 128)
        self.lidar_conv3 = self._conv_bn_relu(128, 256)

        # Fusion branch: same layout, separate weights.
        self.hsifuse_conv1 = self._conv_bn_relu(hsi_channels, 64)
        self.hsifuse_conv2 = self._conv_bn_relu(64, 128)
        self.hsifuse_conv3 = self._conv_bn_relu(128, 256)

        self.lidarfuse_conv1 = self._conv_bn_relu(lidar_channels, 64)
        self.lidarfuse_conv2 = self._conv_bn_relu(64, 128)
        self.lidarfuse_conv3 = self._conv_bn_relu(128, 256)

        # One ViT_spa fuser per fusion stage; image_size matches that stage's map side.
        self.vit_spa1 = ViT_spa(image_size=9, patch_size=1, dim=128, depth=2, heads=4, mlp_dim=256,
                                channels=64, dim_head=32, dropout=0., emb_dropout=0)
        self.vit_spa2 = ViT_spa(image_size=7, patch_size=1, dim=256, depth=2, heads=4, mlp_dim=512,
                                channels=128, dim_head=64, dropout=0., emb_dropout=0)
        self.vit_spa3 = ViT_spa(image_size=5, patch_size=1, dim=512, depth=2, heads=4, mlp_dim=1024,
                                channels=256, dim_head=128, dropout=0., emb_dropout=0)

        # Dropout (drop probability 0.6) before each classifier head.
        self.drop1 = nn.Dropout(0.6)
        self.drop2 = nn.Dropout(0.6)
        self.drop3 = nn.Dropout(0.6)

        # in_features = channels * 5 pooled values per channel (see class docstring).
        self.fusionlinear_1 = nn.Linear(in_features=256 * 5, out_features=num_classes)
        self.fusionlinear_2 = nn.Linear(in_features=256 * 5, out_features=num_classes)
        self.fusionlinear_3 = nn.Linear(in_features=512 * 5, out_features=num_classes)

        # Mixing weights for the unimodal logits (squashed by sigmoid in forward).
        self.weight = nn.Parameter(torch.ones(2))

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Build one stage: 3x3 unpadded Conv2d -> BatchNorm2d -> ReLU(inplace)."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _pool_flatten(feature_map):
        """Map [B, C, H, W] to [B, C * (H*W // 5)] via spatial flatten + max_pool1d(kernel 5)."""
        pooled = F.max_pool1d(feature_map.flatten(2), kernel_size=5)
        return pooled.flatten(1)

    def forward(self, hsi, lidar):
        # Unimodal branches.
        hsi_stage3 = self.hsi_conv3(self.hsi_conv2(self.hsi_conv1(hsi)))
        lidar_stage3 = self.lidar_conv3(self.lidar_conv2(self.lidar_conv1(lidar)))

        # Fusion branch: conv each modality, then add the ViT_spa update back in.
        fusion_stages = (
            (self.hsifuse_conv1, self.lidarfuse_conv1, self.vit_spa1),
            (self.hsifuse_conv2, self.lidarfuse_conv2, self.vit_spa2),
            (self.hsifuse_conv3, self.lidarfuse_conv3, self.vit_spa3),
        )
        hsi_fuse, lidar_fuse = hsi, lidar
        for hsi_conv, lidar_conv, fuser in fusion_stages:
            hsi_fuse = hsi_conv(hsi_fuse)
            lidar_fuse = lidar_conv(lidar_fuse)
            hsi_delta, lidar_delta = fuser(hsi_fuse, lidar_fuse)
            hsi_fuse = hsi_fuse + hsi_delta
            lidar_fuse = lidar_fuse + lidar_delta

        fused = torch.cat((hsi_fuse, lidar_fuse), dim=1)

        hsi_logits = self.fusionlinear_1(self.drop1(self._pool_flatten(hsi_stage3)))
        lidar_logits = self.fusionlinear_2(self.drop2(self._pool_flatten(lidar_stage3)))
        fused_logits = self.fusionlinear_3(self.drop3(self._pool_flatten(fused)))

        weight = torch.sigmoid(self.weight)
        return weight[0] * hsi_logits + weight[1] * lidar_logits + fused_logits

# Build the model with its default configuration and move it to `device`
# (defined elsewhere; presumably a torch.device — confirm at the call site).
net = Net()
net = net.to(device)

6.训练结果

![训练结果](<截屏2022-03-20 18.39.10.png>)