PyTorch / XGBoost
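When training on a large dataset, reading the data iteratively (one batch at a time instead of loading the whole dataset into memory at once) is a very suitable choice, and PyTorch supports this style of loading. The following describes how to load data iteratively in XGBoost.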

In-memory data loading

    import numpy as np
    import pandas as pd
    import xgboost as xgb

    class IterLoadForDMatrix(xgb.core.DataIter):
        """Feed a pandas DataFrame to XGBoost in fixed-size row batches."""

        def __init__(self, df=None, features=None, target=None, batch_size=256 * 1024):
            self.features = features
            self.target = target
            self.df = df
            self.batch_size = batch_size
            self.batches = int(np.ceil(len(df) / self.batch_size))
            self.it = 0  # set iterator to 0
            super().__init__()

        def reset(self):
            """Reset the iterator to the first batch."""
            self.it = 0

        def next(self, input_data):
            """Pass the next batch to XGBoost; return 0 when there are no more batches."""
            if self.it == self.batches:
                return 0  # no more batches
            a = self.it * self.batch_size
            b = min((self.it + 1) * self.batch_size, len(self.df))
            dt = pd.DataFrame(self.df.iloc[a:b])
            # Optionally also pass weight=dt['weight'] if the data has a weight column
            input_data(data=dt[self.features], label=dt[self.target])
            self.it += 1
            return 1
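The index arithmetic in `next()` can be checked on its own, without XGBoost. The following is a minimal pure-Python sketch; the helper name `batch_bounds` is hypothetical and not part of XGBoost. It lists the half-open row ranges `[a, b)` that successive `next()` calls slice out, which makes it easy to confirm that every row is visited exactly once and that only the last batch can be shorter than `batch_size`.

```python
from typing import List, Tuple


def batch_bounds(n_rows: int, batch_size: int) -> List[Tuple[int, int]]:
    """Return the half-open [a, b) row ranges visited by successive next() calls."""
    # Integer ceiling division; matches int(np.ceil(len(df) / batch_size))
    n_batches = (n_rows + batch_size - 1) // batch_size
    bounds = []
    for it in range(n_batches):
        a = it * batch_size
        b = min((it + 1) * batch_size, n_rows)  # the last batch may be shorter
        bounds.append((a, b))
    return bounds


print(batch_bounds(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```

With the default `batch_size=256 * 1024`, a DataFrame of about one million rows is therefore split into four batches, the last one partial.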

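Usage (this approach is better suited to GPU training: `DeviceQuantileDMatrix` is designed for the `gpu_hist` tree method and builds the quantized matrix batch by batch instead of first materializing the full dataset):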

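    # train, train_idx and FEATURES are defined elsewhere:
    # the training DataFrame, the selected row labels and the feature column names
    Xy_train = IterLoadForDMatrix(train.loc[train_idx], FEATURES, 'target')
    # In newer XGBoost releases (1.7+), DeviceQuantileDMatrix is deprecated in favour of xgb.QuantileDMatrix
    dtrain = xgb.DeviceQuantileDMatrix(Xy_train, max_bin=256)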
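To see why `reset()` matters, the sketch below simulates the consumer side of the protocol in plain Python. It assumes that XGBoost calls `reset()` and then `next(input_data)` repeatedly until `next` returns 0, and that building a quantile matrix from an iterator may traverse the data more than once (modelled here as two passes). `ListIter` and `drive` are hypothetical stand-ins, not XGBoost APIs; they only mimic the calling pattern so it can run without XGBoost.

```python
from typing import Callable, List


class ListIter:
    """Hypothetical stand-in with the same reset()/next(input_data) protocol
    as IterLoadForDMatrix, but over plain lists instead of a DataFrame."""

    def __init__(self, rows: List[float], labels: List[int], batch_size: int):
        self.rows = rows
        self.labels = labels
        self.batch_size = batch_size
        self.batches = (len(rows) + batch_size - 1) // batch_size
        self.it = 0

    def reset(self) -> None:
        self.it = 0

    def next(self, input_data: Callable) -> int:
        if self.it == self.batches:
            return 0  # end of this pass
        a = self.it * self.batch_size
        b = min(a + self.batch_size, len(self.rows))
        input_data(data=self.rows[a:b], label=self.labels[a:b])
        self.it += 1
        return 1


def drive(iterator, passes: int = 2) -> List[list]:
    """Simulated consumer: for each pass, call reset(), then next() until it returns 0."""
    seen = []
    for _ in range(passes):
        iterator.reset()
        batches = []
        while iterator.next(lambda data, label: batches.append((data, label))):
            pass
        seen.append(batches)
    return seen


it = ListIter([0.1, 0.2, 0.3, 0.4, 0.5], [0, 1, 0, 1, 1], batch_size=2)
passes = drive(it)
print([len(d) for d, _ in passes[0]])  # [2, 2, 1]
```

If `reset()` did not set the counter back to 0, the second pass would receive no batches at all, which is why both iterator classes in this post implement it.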

Reference: https://xgboost.readthedocs.io/en/latest/python/examples/quantile_data_iterator.html

Iterative loading with external memory

    import os
    from typing import Callable, List

    import xgboost
    from sklearn.datasets import load_svmlight_file

    class Iterator(xgboost.DataIter):
        """Load one svmlight file per next() call."""

        def __init__(self, svm_file_paths: List[str]):
            self._file_paths = svm_file_paths
            self._it = 0
            # XGBoost writes its external-memory cache files using this prefix
            super().__init__(cache_prefix=os.path.join(".", "cache"))

        def next(self, input_data: Callable):
            if self._it == len(self._file_paths):
                # return 0 to let XGBoost know this is the end of iteration
                return 0
            X, y = load_svmlight_file(self._file_paths[self._it])
            # label must be passed as a keyword argument
            input_data(data=X, label=y)
            self._it += 1
            return 1

        def reset(self):
            """Reset the iterator to its beginning"""
            self._it = 0

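Usage (this approach is better suited to CPU training; because the iterator is created with a `cache_prefix`, the `DMatrix` is built in external-memory mode and the batches are cached on disk instead of all being held in RAM):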

    it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"])
    Xy = xgboost.DMatrix(it)
    # Other tree methods, including ``hist`` and ``gpu_hist``, also work but have some caveats;
    # see the XGBoost external memory tutorial referenced below.
    booster = xgboost.train({"tree_method": "approx"}, Xy)
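The usage above assumes the shard files already exist. One way to produce them from in-memory rows is sketched below in plain Python; the function `write_svmlight_shards` and its parameters are hypothetical. It writes the svmlight/LIBSVM text format (`<label> <index>:<value> ...`) with 0-based feature indices and deliberately keeps zero-valued features, so that every shard contains the same feature indices and `load_svmlight_file`, which infers the number of columns from the data when `n_features` is not given, reads the same width from each file.

```python
import os
from typing import List, Sequence


def write_svmlight_shards(
    X: Sequence[Sequence[float]],
    y: Sequence[float],
    n_shards: int,
    out_dir: str = ".",
) -> List[str]:
    """Split (X, y) row-wise into n_shards svmlight text files named file_<k>.svm."""
    paths = []
    rows_per_shard = (len(X) + n_shards - 1) // n_shards  # ceiling division
    for s in range(n_shards):
        path = os.path.join(out_dir, f"file_{s}.svm")
        start = s * rows_per_shard
        stop = min(start + rows_per_shard, len(X))
        with open(path, "w") as f:
            for i in range(start, stop):
                # One line per row: "<label> 0:<x0> 1:<x1> ..." (0-based, zeros kept)
                feats = " ".join(f"{j}:{v}" for j, v in enumerate(X[i]))
                f.write(f"{y[i]} {feats}\n")
        paths.append(path)
    return paths
```

For example, `write_svmlight_shards(X, y, n_shards=3)` creates `file_0.svm`, `file_1.svm` and `file_2.svm` in the current directory, matching the list passed to `Iterator`. For larger arrays, scikit-learn's `dump_svmlight_file` can write the same format.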

Reference: https://xgboost.readthedocs.io/en/stable/tutorials/external_memory.html