cls1 basic
基本写法
import numpy as npimport torchfrom torch import nnfrom detectron2.data.build import TrainingSamplerfrom pyxlpr.d2 import D2Trainerclass ValueCls(D2Trainer):""" 输入数据是3个[1, 10]的整数,做分类任务,判断和是否在(10,20)之间 """@classmethoddef build_model(cls, cfg):""" 自定义模型 """device = torch.device(cfg.MODEL.DEVICE)class ParityModel(nn.Module):def __init__(self):super().__init__()self.classifier = nn.Sequential(nn.Linear(in_features=3, out_features=2),nn.Sigmoid(),nn.Linear(in_features=2, out_features=2),nn.Sigmoid(),nn.Linear(in_features=2, out_features=1),nn.Sigmoid(),)self.criteon = nn.BCELoss() # 二分类交叉熵损失def forward(self, batched_inputs):x = batched_inputs[0].type(torch.FloatTensor).to(device)logits = self.classifier(x)if self.training:y = batched_inputs[1].unsqueeze(-1).type(torch.FloatTensor).to(device)loss = self.criteon(logits, y)return {'loss': loss} # 损失要以字典的形式返回else:y_hat = (logits > 0.5).type(torch.int)return y_hatreturn ParityModel().to(device)@classmethoddef build_train_loader(cls, cfg):""" 自定义训练集 """n = 5000 # 数据量data = np.random.randint(1, 10, [n, 3]) # 数据dataloader = torch.utils.data.DataLoader([(x, 10 < sum(x) < 20) for x in data],sampler=TrainingSampler(n), # 无限取用的数据流,避免一个epoch完迭代结束batch_size=cfg.SOLVER.IMS_PER_BATCH)return dataloader@classmethoddef build_test_loader(cls, cfg, dataset_name):""" 自定义验证集 """n = 100data = np.random.randint(1, 10, [n, 3])dataloader = torch.utils.data.DataLoader([(x, 10 < sum(x) < 20) for x in data],batch_size=cfg.SOLVER.IMS_PER_BATCH)return dataloader@classmethoddef build_evaluator(cls, cfg, dataset_name, output_folder=None):""" 自定义测评方法 """from detectron2.evaluation import DatasetEvaluatorclass Eval(DatasetEvaluator):def reset(self):self.acc = 0 # 分类正确的样本数self.total = 0 # 总样本数def process(self, inputs, outputs):""":param inputs: dataloader每次batch得到的内容结构:param outputs: model.inference对应的返回值:return:"""self.total += len(outputs)self.acc += int(sum(inputs[1].unsqueeze(-1) == outputs)[0])def evaluate(self):# 结果要以字典的形式返回,第一层是任务名,第二层是多个指标结果return {'Classification': {'Accuracy': self.acc / self.total, 'total': self.total}}return Eval()CONFIGS = r"""MODEL:DEVICE: 'cpu'SOLVER:IMS_PER_BATCH: 16MAX_ITER: 3125BASE_LR: 0.01OUTPUT_DIR: "/home/chenkunze/det2/valuecls""""if __name__ == "__main__":np.random.seed(4101)torch.manual_seed(4102)ValueCls(CONFIGS).launch()# 目前精度:0.6100# 框架不一样,相比于原生pytorch,跑出来确实有些不同吧
cls2 enhance
加强版
