The basic idea of adversarial training is to perturb the model during training so that it becomes more resistant to noise and therefore more robust.
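For the FGM (Fast Gradient Method) variant used here, the perturbation is applied to the embedding weights. Writing $E$ for the embedding matrix, $L$ for the loss and $g = \nabla_{E} L$ for its gradient (notation introduced here for clarity), the `attack()` method in the code below computes

$$
r_{adv} = \epsilon \cdot \frac{g}{\lVert g \rVert_2}
$$

temporarily replaces $E$ with $E + r_{adv}$ for a second forward/backward pass, and then restores the original weights.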
I test it with the example below.
Open-source experiment code: https://github.com/Ricardokevins/NLP/tree/main/TextClassification (it can be run directly).
Preliminary experimental results:
| Model | Acc without FGM | Acc with FGM |
|---|---|---|
| BiLSTM | 0.71875 | 0.70486 |
| BiLSTM+Attention1 | 0.77778 | 0.78125 |
| BiLSTM+Attention2 | 0.74306 | 0.77431 |
| Transformer | 0.77083 | 0.71875 |
Some models improve with FGM while others regress. The regressions may come from unstable training, from the model not having converged, or simply from the test set being too small.
FGM implementation:

```python
import torch


class FGM():
    """Fast Gradient Method: perturb the embedding weights along the gradient direction."""

    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='embedding'):
        # emb_name should match the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                # back up the original weights before perturbing them
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    # r_adv = epsilon * g / ||g||_2
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='embedding'):
        # emb_name should match the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
```
Training loop with FGM:

```python
import torch
import torch.nn as nn
# Assumed imports: AdamW and the linear warmup scheduler as provided by the transformers library.
from transformers import AdamW, get_linear_schedule_with_warmup

# net, train_dataset, batch_size, num_epochs, USE_CUDA and test() are assumed to be
# defined elsewhere (model, data and evaluation code from the linked repository).
Attack_with_FGM = True


def train_with_FGM():
    net.train()
    fgm = FGM(net)
    optimizer = AdamW(net.parameters(), lr=2e-4, eps=1e-8)
    train_iter = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
    crossentropyloss = nn.CrossEntropyLoss()

    total_steps = len(train_iter) * num_epochs
    print("----total step: ", total_steps, "----")
    print("----warmup step: ", int(total_steps * 0.15), "----")
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * 0.15),
        num_training_steps=total_steps,
    )

    best_acc = 0
    for epoch in range(num_epochs):
        correct = 0
        total = 0
        step = 0
        avg_loss = 0
        net.train()
        for train_text, label in train_iter:
            step += 1
            # skip the last incomplete batch
            if train_text.size(0) != batch_size:
                break
            train_text = train_text.reshape(batch_size, -1)
            label = label.reshape(-1)
            if USE_CUDA:
                train_text = train_text.cuda()
                label = label.cuda()

            # 1. normal forward/backward pass: gradients of the clean loss
            logits = net(train_text)
            loss = crossentropyloss(logits, label)
            loss.backward()
            avg_loss += loss.item()

            # 2. perturb the embeddings and accumulate the adversarial gradients
            fgm.attack()
            logits = net(train_text)
            loss_adv = crossentropyloss(logits, label)
            loss_adv.backward()
            fgm.restore()  # 3. restore the original embedding weights

            # 4. update with the accumulated (clean + adversarial) gradients
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            _, pred = torch.max(logits, 1)
            correct += pred.data.eq(label.data).cpu().sum()
            total += batch_size

        print("epoch", epoch,
              "loss:", avg_loss / max(step, 1),
              "train acc:", float(correct) / max(total, 1))
        cur_acc = test()
        if best_acc < cur_acc:
            best_acc = cur_acc
        print(best_acc)
    return
```
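Note that `optimizer.zero_grad()` is only called after `optimizer.step()`, so the backward pass on the clean loss and the backward pass on the adversarial loss accumulate into the same gradient buffers; each parameter update therefore uses the sum of the clean and adversarial gradients.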
