cls1 basic

基本写法

import numpy as np
import torch
from detectron2.data.build import TrainingSampler
from detectron2.evaluation import DatasetEvaluator
from pyxlpr.d2 import D2Trainer
from torch import nn
  6. class ValueCls(D2Trainer):
  7. """ 输入数据是3个[1, 10]的整数,做分类任务,判断和是否在(10,20)之间 """
  8. @classmethod
  9. def build_model(cls, cfg):
  10. """ 自定义模型 """
  11. device = torch.device(cfg.MODEL.DEVICE)
  12. class ParityModel(nn.Module):
  13. def __init__(self):
  14. super().__init__()
  15. self.classifier = nn.Sequential(
  16. nn.Linear(in_features=3, out_features=2),
  17. nn.Sigmoid(),
  18. nn.Linear(in_features=2, out_features=2),
  19. nn.Sigmoid(),
  20. nn.Linear(in_features=2, out_features=1),
  21. nn.Sigmoid(),
  22. )
  23. self.criteon = nn.BCELoss() # 二分类交叉熵损失
  24. def forward(self, batched_inputs):
  25. x = batched_inputs[0].type(torch.FloatTensor).to(device)
  26. logits = self.classifier(x)
  27. if self.training:
  28. y = batched_inputs[1].unsqueeze(-1).type(torch.FloatTensor).to(device)
  29. loss = self.criteon(logits, y)
  30. return {'loss': loss} # 损失要以字典的形式返回
  31. else:
  32. y_hat = (logits > 0.5).type(torch.int)
  33. return y_hat
  34. return ParityModel().to(device)
  35. @classmethod
  36. def build_train_loader(cls, cfg):
  37. """ 自定义训练集 """
  38. n = 5000 # 数据量
  39. data = np.random.randint(1, 10, [n, 3]) # 数据
  40. dataloader = torch.utils.data.DataLoader([(x, 10 < sum(x) < 20) for x in data],
  41. sampler=TrainingSampler(n), # 无限取用的数据流,避免一个epoch完迭代结束
  42. batch_size=cfg.SOLVER.IMS_PER_BATCH)
  43. return dataloader
  44. @classmethod
  45. def build_test_loader(cls, cfg, dataset_name):
  46. """ 自定义验证集 """
  47. n = 100
  48. data = np.random.randint(1, 10, [n, 3])
  49. dataloader = torch.utils.data.DataLoader([(x, 10 < sum(x) < 20) for x in data],
  50. batch_size=cfg.SOLVER.IMS_PER_BATCH)
  51. return dataloader
  52. @classmethod
  53. def build_evaluator(cls, cfg, dataset_name, output_folder=None):
  54. """ 自定义测评方法 """
  55. from detectron2.evaluation import DatasetEvaluator
  56. class Eval(DatasetEvaluator):
  57. def reset(self):
  58. self.acc = 0 # 分类正确的样本数
  59. self.total = 0 # 总样本数
  60. def process(self, inputs, outputs):
  61. """
  62. :param inputs: dataloader每次batch得到的内容结构
  63. :param outputs: model.inference对应的返回值
  64. :return:
  65. """
  66. self.total += len(outputs)
  67. self.acc += int(sum(inputs[1].unsqueeze(-1) == outputs)[0])
  68. def evaluate(self):
  69. # 结果要以字典的形式返回,第一层是任务名,第二层是多个指标结果
  70. return {'Classification': {'Accuracy': self.acc / self.total, 'total': self.total}}
  71. return Eval()
  72. CONFIGS = r"""
  73. MODEL:
  74. DEVICE: 'cpu'
  75. SOLVER:
  76. IMS_PER_BATCH: 16
  77. MAX_ITER: 3125
  78. BASE_LR: 0.01
  79. OUTPUT_DIR: "/home/chenkunze/det2/valuecls"
  80. """
  81. if __name__ == "__main__":
  82. np.random.seed(4101)
  83. torch.manual_seed(4102)
  84. ValueCls(CONFIGS).launch()
  85. # 目前精度:0.6100
  86. # 框架不一样,相比于原生pytorch,跑出来确实有些不同吧

cls2 enhance

加强版