The basic idea of adversarial training is to add small perturbations to the model (here, to its input embeddings) so that it becomes more resistant to noise and more robust.
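Concretely, the FGM (Fast Gradient Method) variant used here perturbs the word embeddings along the gradient of the loss, scaled to length epsilon: r_adv = epsilon * g / ||g||_2. The tiny sketch below only evaluates this formula on made-up numbers; it is not part of the experiment code.

    import torch

    # Toy illustration of the FGM perturbation r_adv = epsilon * g / ||g||_2
    epsilon = 1.0
    grad = torch.tensor([0.3, -0.4, 0.0])        # pretend gradient of the loss w.r.t. one embedding row
    r_adv = epsilon * grad / torch.norm(grad)    # tensor([0.6, -0.8, 0.0]): unit direction scaled by epsilon
    embedding_row = torch.tensor([1.0, 2.0, 3.0])
    perturbed_row = embedding_row + r_adv        # the value FGM temporarily writes into the embedding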
    The example below is used for testing.
    Open-source experiment code: https://github.com/Ricardokevins/NLP/tree/main/TextClassification (it can be run directly).
    Preliminary experimental results:

    Model               Acc without FGM    Acc with FGM
    BiLSTM              0.71875            0.70486
    BiLSTM+Attention1   0.77778            0.78125
    BiLSTM+Attention2   0.74306            0.77431
    Transformer         0.77083            0.71875

    Some models improve with FGM while others get worse. The regressions may mean training actually blew up, that the model had not converged, or simply that the test set is too small.

    import torch

    class FGM():
        def __init__(self, model):
            self.model = model
            self.backup = {}

        def attack(self, epsilon=1., emb_name='embedding'):
            # Replace emb_name with the name of the embedding parameter in your model
            for name, param in self.model.named_parameters():
                if param.requires_grad and emb_name in name:
                    self.backup[name] = param.data.clone()       # save the clean embedding
                    norm = torch.norm(param.grad)
                    if norm != 0 and not torch.isnan(norm):
                        r_at = epsilon * param.grad / norm       # perturbation along the gradient direction
                        param.data.add_(r_at)

        def restore(self, emb_name='embedding'):
            # Replace emb_name with the name of the embedding parameter in your model
            for name, param in self.model.named_parameters():
                if param.requires_grad and emb_name in name:
                    assert name in self.backup
                    param.data = self.backup[name]               # restore the clean embedding
            self.backup = {}
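    Note that attack() and restore() only touch parameters whose name contains emb_name, so emb_name has to match how the embedding is actually named in your model. A quick way to check the names (the ToyClassifier below is a hypothetical stand-in, not the model from the repo):

    import torch.nn as nn

    # Hypothetical toy model, only to show what named_parameters() returns
    class ToyClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(1000, 64)
            self.fc = nn.Linear(64, 2)

        def forward(self, x):
            return self.fc(self.embedding(x).mean(dim=1))

    model = ToyClassifier()
    for name, _ in model.named_parameters():
        print(name)                      # "embedding.weight", "fc.weight", "fc.bias"

    fgm = FGM(model)                     # the default emb_name='embedding' matches "embedding.weight"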
    import torch
    import torch.nn as nn
    from torch.optim import AdamW                 # the repo may import AdamW from transformers instead
    from transformers import get_linear_schedule_with_warmup

    # net, train_dataset, batch_size, num_epochs, USE_CUDA, test() and ProgressBar
    # are defined elsewhere in the repository.

    Attack_with_FGM = True

    def train_with_FGM():
        net.train()
        fgm = FGM(net)
        # optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=0.01)
        # optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0)
        optimizer = AdamW(net.parameters(), lr=2e-4, eps=1e-8)
        # optimizer = AdamW(net.parameters(), lr=learning_rate)
        # train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        train_iter = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
        pre_loss = 1
        crossentropyloss = nn.CrossEntropyLoss()
        total_steps = len(train_iter) * num_epochs
        print("----total step: ", total_steps, "----")
        print("----warmup step: ", int(total_steps * 0.15), "----")
        best_acc = 0
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=int(total_steps * 0.15),
                                                    num_training_steps=total_steps)
        for epoch in range(num_epochs):
            correct = 0
            total = 0
            iter = 0
            pbar = ProgressBar(n_total=len(train_iter), desc='Training')
            net.train()
            avg_loss = 0
            for train_text, label in train_iter:
                iter += 1
                if train_text.size(0) != batch_size:
                    break
                train_text = train_text.reshape(batch_size, -1)
                label = label.reshape(-1)
                if USE_CUDA:
                    train_text = train_text.cuda()
                    label = label.cuda()
                logits = net(train_text)
                loss = crossentropyloss(logits, label)
                loss.backward()                  # gradients of the clean loss (also used by fgm.attack)
                avg_loss += loss.item()
                fgm.attack()                     # perturb the embedding along the gradient
                logits = net(train_text)
                loss_adv = crossentropyloss(logits, label)
                loss_adv.backward()              # accumulate gradients of the adversarial loss
                fgm.restore()                    # restore the original embedding
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                # pbar(iter, {'loss': avg_loss / iter})
                _, logits = torch.max(logits, 1)
                correct += logits.data.eq(label.data).cpu().sum()
                total += batch_size
                loss = loss.detach().cpu()
            # print("\nepoch ", str(epoch), " loss: ", loss.mean().numpy().tolist(), "Acc:", correct.numpy().tolist() / total)
            cur_acc = test()
            if best_acc < cur_acc:
                best_acc = cur_acc
        print(best_acc)
        return
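    For reference, here is a minimal self-contained sketch of a single FGM training step on random data, mirroring the loop above; the model is the hypothetical ToyClassifier from the earlier sketch, and all numbers are placeholders rather than the repo's settings.

    import torch
    import torch.nn as nn

    model = ToyClassifier()                        # toy model from the sketch above
    fgm = FGM(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    criterion = nn.CrossEntropyLoss()

    tokens = torch.randint(0, 1000, (8, 16))       # batch of 8 sequences of length 16
    labels = torch.randint(0, 2, (8,))

    loss = criterion(model(tokens), labels)
    loss.backward()                                # 1) gradients on the clean batch
    fgm.attack()                                   # 2) perturb the embedding using those gradients
    loss_adv = criterion(model(tokens), labels)
    loss_adv.backward()                            # 3) accumulate gradients on the adversarial batch
    fgm.restore()                                  # 4) put the original embedding back
    optimizer.step()                               # 5) update with clean + adversarial gradients
    optimizer.zero_grad()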