The basic idea of adversarial training is to perturb the model (here, its embedding weights) so that it becomes more resistant to noise, improving robustness.
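Concretely, with FGM (Fast Gradient Method, Miyato et al.) the perturbation applied to the embeddings is the loss gradient, L2-normalized and scaled by a step size ε, which is exactly what the `attack()` method below computes:

$$r_{adv} = \epsilon \cdot \frac{g}{\lVert g \rVert_2}, \qquad g = \nabla_{emb}\, L(x, y)$$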
The example below is used for testing. Open-source experiment code (runnable directly): https://github.com/Ricardokevins/NLP/tree/main/TextClassification

Preliminary experiment results:
| Model | Acc without FGM | Acc with FGM |
|---|---|---|
| BiLSTM | 0.71875 | 0.70486 |
| BiLSTM+Attention1 | 0.77778 | 0.78125 |
| BiLSTM+Attention2 | 0.74306 | 0.77431 |
| Transformer | 0.77083 | 0.71875 |
Some models improve with FGM while others regress. Possible explanations: the adversarial training genuinely destabilized ("blew up") the model, training had not converged, or the test set is simply too small for a reliable comparison.
```python
import torch

class FGM():
    """Fast Gradient Method: adds a gradient-direction perturbation to the embedding weights."""
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='embedding'):
        # emb_name must match the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                # back up the clean weights so restore() can revert the perturbation
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    # r_adv = epsilon * g / ||g||_2
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='embedding'):
        # emb_name must match the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
```
The training loop follows the standard FGM pattern: a normal backward pass on the clean batch, `attack()` to perturb the embeddings, a second forward/backward pass to accumulate the adversarial gradient, `restore()` to revert the embeddings, and only then the optimizer step.

```python
import torch.nn as nn
from transformers import AdamW, get_linear_schedule_with_warmup

Attack_with_FGM = True

def train_with_FGM():
    # Assumes the surrounding script (see the linked repo) defines:
    # net, train_dataset, batch_size, num_epochs, USE_CUDA, test(), ProgressBar
    net.train()
    fgm = FGM(net)
    optimizer = AdamW(net.parameters(), lr=2e-4, eps=1e-8)
    train_iter = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
    crossentropyloss = nn.CrossEntropyLoss()
    total_steps = len(train_iter) * num_epochs
    warmup_steps = int(total_steps * 0.15)
    print("----total step: ", total_steps, "----")
    print("----warmup step: ", warmup_steps, "----")
    best_acc = 0
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    for epoch in range(num_epochs):
        correct = 0
        total = 0
        pbar = ProgressBar(n_total=len(train_iter), desc='Training')
        net.train()
        avg_loss = 0
        for step, (train_text, label) in enumerate(train_iter, start=1):
            if train_text.size(0) != batch_size:
                break  # drop the last incomplete batch
            train_text = train_text.reshape(batch_size, -1)
            label = label.reshape(-1)
            if USE_CUDA:
                train_text = train_text.cuda()
                label = label.cuda()
            # 1. Normal forward/backward pass on the clean embeddings
            logits = net(train_text)
            loss = crossentropyloss(logits, label)
            loss.backward()
            avg_loss += loss.item()
            # 2. Perturb the embeddings, then accumulate the adversarial gradient
            fgm.attack()
            logits = net(train_text)
            loss_adv = crossentropyloss(logits, label)
            loss_adv.backward()
            # 3. Restore the clean embeddings and update on the combined gradient
            fgm.restore()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            # pbar(step, {'loss': avg_loss / step})
            _, preds = torch.max(logits, 1)
            correct += preds.data.eq(label.data).cpu().sum()
            total += batch_size
        # print("\nepoch", epoch, "loss:", avg_loss / step, "Acc:", correct.item() / total)
        cur_acc = test()
        if best_acc < cur_acc:
            best_acc = cur_acc
            print(best_acc)
    return
```
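The `test()` call above is the project's evaluation routine, defined elsewhere in the linked repo. A minimal sketch of what it is assumed to do, with `test_dataset` as a hypothetical name for the held-out split:

```python
def test():
    # Hypothetical sketch of the evaluation routine; the actual implementation
    # lives in the linked repository. Reuses the globals from training and
    # assumes a held-out split named test_dataset.
    net.eval()
    test_iter = torch.utils.data.DataLoader(test_dataset, batch_size)
    correct, total = 0, 0
    with torch.no_grad():
        for text, label in test_iter:
            text = text.reshape(text.size(0), -1)
            label = label.reshape(-1)
            if USE_CUDA:
                text, label = text.cuda(), label.cuda()
            preds = net(text).argmax(dim=1)
            correct += (preds == label).sum().item()
            total += label.size(0)
    return correct / total
```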