1. 定义网络基本单元（残差连接 Residual、层归一化 PreNorm、前馈网络 FeedForward）
# ---- Basic network building blocks ----

class Residual(nn.Module):
    """Wrap a module with a residual (skip) connection: y = fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Residual connection: fn(x) + x
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    """Apply LayerNorm to the input before the wrapped module (pre-norm style)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Normalize over the last dimension first, then apply the wrapped layer.
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
2. 多头注意力机制（跨模态注意力：HSI 流提供 Q/K，LiDAR 流提供 V）
class Attention(nn.Module):
    """Multi-head cross-modal attention over a stacked [batch, 2, n, dim] input.

    Stream 0 (HSI) supplies queries and keys; stream 1 (LiDAR) supplies the
    values, so the HSI attention pattern gathers LiDAR features. The attended
    result is duplicated along the stream axis so the output shape matches the
    input shape [batch, 2, n, dim].

    Q, K and V use three separate bias-free projection matrices.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # 1/sqrt(d_head) scaling for the dot-product scores.
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        b, _, n, _ = x.shape
        heads = self.heads
        # Split the two modality streams.
        hsi = x[:, 0, :, :]
        lidar = x[:, 1, :, :]
        # Project, then split heads:
        # [b, n, heads*dim_head] -> [b, heads, n, dim_head]
        q = self.to_q(hsi).view(b, n, heads, -1).transpose(1, 2)
        k = self.to_k(hsi).view(b, n, heads, -1).transpose(1, 2)
        v = self.to_v(lidar).view(b, n, heads, -1).transpose(1, 2)
        # Scaled dot-product scores: [b, heads, n, n].
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        # Attention-weighted sum of values, then merge heads back:
        # [b, heads, n, dim_head] -> [b, n, heads*dim_head]
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        # Linear projection + dropout back to model dim.
        out = self.to_out(out)
        # Duplicate along the stream axis so output shape matches the input.
        out = out.unsqueeze(1)
        return torch.cat([out, out], dim=1)
3.Transformer
class Transformer(nn.Module):
    """Stack of `depth` pre-norm layers, each = residual attention + residual FFN."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Alternate multi-head attention and feed-forward sublayers, each
            # wrapped in LayerNorm (PreNorm) and a residual connection.
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))),
            ]))

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = attention(x)
            x = feed_forward(x)
        return x
4. ViT 特征融合器（ViT_spa：对 HSI 与 LiDAR 特征图做逐阶段的 Transformer 融合）
class ViT_spa(nn.Module):
    """ViT-style feature fuser for an (HSI, LiDAR) feature-map pair.

    Both inputs are patch-embedded, given separate learned positional
    embeddings, stacked along a new "stream" axis and passed through a shared
    Transformer; the two streams are then projected back to patch space and
    reassembled into feature maps of the original spatial size.
    """

    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        self.patch_size = patch_size
        # Separate learned positional embeddings per modality.
        self.hsi_pos = nn.Parameter(torch.randn(1, num_patches, dim))
        self.lidar_pos = nn.Parameter(torch.randn(1, num_patches, dim))
        # Patch pixels -> token embedding, one projection per modality.
        self.hsi_to_embedding = nn.Linear(patch_dim, dim)
        self.lidar_to_embedding = nn.Linear(patch_dim, dim)
        self.hsi_dropout = nn.Dropout(emb_dropout)
        self.lidar_dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        # Token embedding -> flattened patch pixels, for reassembly.
        self.embedding_to_hsi = nn.Linear(dim, patch_dim)
        self.embedding_to_lidar = nn.Linear(dim, patch_dim)

    def forward(self, hsi, lidar):
        """Fuse `hsi` and `lidar` feature maps of shape [b, c, H, W]; returns a
        (hsi, lidar) pair with the same spatial size as the inputs."""
        p = self.patch_size
        _, _, h, _ = hsi.shape
        hh = h // p  # patches per spatial side (was int(h / p))
        # [b, c, H, W] -> [b, num_patches, patch_dim]
        hsi_embed = rearrange(hsi, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        lidar_embed = rearrange(lidar, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        hsi_embed = self.hsi_to_embedding(hsi_embed)
        lidar_embed = self.lidar_to_embedding(lidar_embed)
        n = hsi_embed.shape[1]
        # Out-of-place add (was `+=`): same forward result, and avoids
        # in-place mutation of an autograd intermediate.
        hsi_embed = hsi_embed + self.hsi_pos[:, :n]
        lidar_embed = lidar_embed + self.lidar_pos[:, :n]
        hsi_embed = self.hsi_dropout(hsi_embed)
        lidar_embed = self.lidar_dropout(lidar_embed)
        # Stack the two modalities along a new stream axis: [b, 2, n, dim].
        x = torch.stack([hsi_embed, lidar_embed], dim=1)
        x = self.transformer(x)
        # Split the streams again. The original applied `.squeeze(1)` here —
        # a no-op except when n == 1, where it would wrongly drop the patch
        # axis; it has been removed.
        hsi_tokens = x[:, 0, :, :]
        lidar_tokens = x[:, 1, :, :]
        hsi_out = self.embedding_to_hsi(hsi_tokens)
        lidar_out = self.embedding_to_lidar(lidar_tokens)
        # Tokens back to feature maps: [b, n, patch_dim] -> [b, c, H, W].
        hsi_out = rearrange(hsi_out, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=hh, p1=p, p2=p)
        lidar_out = rearrange(lidar_out, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=hh, p1=p, p2=p)
        return hsi_out, lidar_out
5.定义并实现网络模型
class Net(nn.Module):
    """Two-branch (HSI + LiDAR) CNN with ViT-based cross-modal fusion per stage.

    Three classification heads are produced: one per single-modality branch
    and one from the fused features; the final logits are a learned weighted
    sum of the three. Spatial sizes 9 -> 7 -> 5 after each valid 3x3 conv
    imply 11x11 input patches (30-band HSI, 1-band LiDAR) — confirm against
    the data loader.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """3x3 conv (no padding, shrinks each spatial dim by 2) + BN + ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _pool_flatten(feature_map):
        """[b, c, h, w] -> flatten spatial dims, max-pool1d(kernel=5), flatten to 2-D."""
        _, c, h, w = feature_map.shape
        flat = feature_map.reshape(-1, c, h * w)
        pooled = F.max_pool1d(flat, kernel_size=5)
        return pooled.reshape(-1, pooled.shape[1] * pooled.shape[2])

    def __init__(self):
        super(Net, self).__init__()
        # Plain (unfused) per-modality branches.
        self.hsi_conv1 = self._conv_block(30, 64)
        self.hsi_conv2 = self._conv_block(64, 128)
        self.hsi_conv3 = self._conv_block(128, 256)
        self.lidar_conv1 = self._conv_block(1, 64)
        self.lidar_conv2 = self._conv_block(64, 128)
        self.lidar_conv3 = self._conv_block(128, 256)
        # Fusion branches: same topology, separate weights.
        self.hsifuse_conv1 = self._conv_block(30, 64)
        self.hsifuse_conv2 = self._conv_block(64, 128)
        self.hsifuse_conv3 = self._conv_block(128, 256)
        self.lidarfuse_conv1 = self._conv_block(1, 64)
        self.lidarfuse_conv2 = self._conv_block(64, 128)
        self.lidarfuse_conv3 = self._conv_block(128, 256)
        # ViT cross-fusion after each conv stage (spatial size 9 -> 7 -> 5).
        self.vit_spa1 = ViT_spa(image_size=9, patch_size=1, dim=128, depth=2, heads=4, mlp_dim=256,
                                channels=64, dim_head=32, dropout=0., emb_dropout=0)
        self.vit_spa2 = ViT_spa(image_size=7, patch_size=1, dim=256, depth=2, heads=4, mlp_dim=512,
                                channels=128, dim_head=64, dropout=0., emb_dropout=0)
        self.vit_spa3 = ViT_spa(image_size=5, patch_size=1, dim=512, depth=2, heads=4, mlp_dim=1024,
                                channels=256, dim_head=128, dropout=0., emb_dropout=0)
        # Dropout with probability 0.6 before each classifier head.
        self.drop1 = nn.Dropout(0.6)
        self.drop2 = nn.Dropout(0.6)
        self.drop3 = nn.Dropout(0.6)
        # 256 channels * 5 pooled positions = 1280; fused branch: 512 * 5 = 2560.
        self.fusionlinear_1 = nn.Linear(in_features=1280, out_features=15)
        self.fusionlinear_2 = nn.Linear(in_features=1280, out_features=15)
        self.fusionlinear_3 = nn.Linear(in_features=2560, out_features=15)
        # Learned (sigmoid-squashed) weights for the two single-modality heads.
        self.weight = nn.Parameter(torch.ones(2))

    def forward(self, hsi, lidar):
        """Return class logits [b, 15] for an (hsi, lidar) input pair."""
        # Unfused CNN branches.
        hsi_stage1 = self.hsi_conv1(hsi)
        lidar_stage1 = self.lidar_conv1(lidar)
        hsi_stage2 = self.hsi_conv2(hsi_stage1)
        lidar_stage2 = self.lidar_conv2(lidar_stage1)
        hsi_stage3 = self.hsi_conv3(hsi_stage2)
        lidar_stage3 = self.lidar_conv3(lidar_stage2)
        # Fusion branch, stage 1: conv, ViT cross-fusion, residual add.
        hsifuse_stage1 = self.hsifuse_conv1(hsi)
        lidarfuse_stage1 = self.lidarfuse_conv1(lidar)
        hsifuse1, lidarfuse1 = self.vit_spa1(hsifuse_stage1, lidarfuse_stage1)
        hsifuse_stage1 = hsifuse_stage1 + hsifuse1
        lidarfuse_stage1 = lidarfuse_stage1 + lidarfuse1
        # Fusion branch, stage 2.
        hsifuse_stage2 = self.hsifuse_conv2(hsifuse_stage1)
        lidarfuse_stage2 = self.lidarfuse_conv2(lidarfuse_stage1)
        hsifuse2, lidarfuse2 = self.vit_spa2(hsifuse_stage2, lidarfuse_stage2)
        hsifuse_stage2 = hsifuse_stage2 + hsifuse2
        lidarfuse_stage2 = lidarfuse_stage2 + lidarfuse2
        # Fusion branch, stage 3.
        hsifuse_stage3 = self.hsifuse_conv3(hsifuse_stage2)
        lidarfuse_stage3 = self.lidarfuse_conv3(lidarfuse_stage2)
        hsifuse3, lidarfuse3 = self.vit_spa3(hsifuse_stage3, lidarfuse_stage3)
        hsifuse_stage3 = hsifuse_stage3 + hsifuse3
        lidarfuse_stage3 = lidarfuse_stage3 + lidarfuse3
        # Concatenate the fused stage-3 maps along channels: [b, 512, 5, 5].
        fuse_feature = torch.cat((hsifuse_stage3, lidarfuse_stage3), dim=1)
        # Pool each branch to a flat feature vector.
        hsi_feature = self._pool_flatten(hsi_stage3)
        lidar_feature = self._pool_flatten(lidar_stage3)
        fuse_feature = self._pool_flatten(fuse_feature)
        # Per-head dropout + linear classifier.
        output_hsi = self.fusionlinear_1(self.drop1(hsi_feature))
        output_lidar = self.fusionlinear_2(self.drop2(lidar_feature))
        output_fuse = self.fusionlinear_3(self.drop3(fuse_feature))
        # Weighted combination of the three heads (fused head weight fixed at 1).
        weight = torch.sigmoid(self.weight)
        return weight[0] * output_hsi + weight[1] * output_lidar + output_fuse
# Instantiate the model and move it to the compute device.
# NOTE(review): `device` must be defined earlier (e.g. torch.device('cuda')) — not visible in this chunk.
net = Net().to(device)
6.训练结果

