🚀 Original link: https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html

In XGBoost 1.0.0, we introduced experimental support for using JSON to save and load XGBoost models and the related hyper-parameters for training. The aim is to replace the old binary internal format with an open format that can be easily reused. Support for the binary format will continue until the JSON format is no longer experimental and has satisfactory performance. This tutorial shares some basic insights into the JSON serialisation method used in XGBoost. Unless explicitly mentioned otherwise, the following sections assume you are using the JSON format. You can enable it by providing a file name with the .json extension when saving or loading a model: booster.save_model('model.json'). More details below.

Before we get started: XGBoost is a gradient boosting library with a focus on tree models, which means there are two distinct parts inside XGBoost:

  1. The model consisting of trees
  2. Hyperparameters and configurations used for building the model.

If you come from the Deep Learning community, the distinction should be familiar. A neural network's structure, composed of weights with fixed tensor operations, is different from the optimizers (like RMSprop) used to train it.

So when one calls booster.save_model (xgb.save in R), XGBoost saves three things: the trees, some model parameters such as the number of input columns in the trained trees, and the objective function. Together these represent the concept of a “model” in XGBoost. The objective is saved as part of the model because it controls the transformation of the global bias (called base_score in XGBoost). Users can share this model with others for prediction or evaluation, or continue training it with a different set of hyper-parameters.
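To make the base_score point concrete, here is a minimal sketch, assuming only three objectives. XGBoost sums tree outputs in margin (raw score) space and applies the objective's inverse link when predicting. A user-facing base_score therefore has to be mapped into margin space first, and that mapping depends on the objective. The helper names base_score_to_margin and margin_to_prediction are hypothetical, not XGBoost APIs, and XGBoost's internal code may differ in details.

```python
import math


def base_score_to_margin(objective: str, base_score: float) -> float:
    """Map base_score from output space to margin space (hypothetical helper)."""
    if objective == "reg:squarederror":
        # Identity link: output space and margin space coincide.
        return base_score
    if objective == "binary:logistic":
        # Logit: the inverse of the sigmoid applied at prediction time.
        if not 0.0 < base_score < 1.0:
            raise ValueError("base_score must lie in (0, 1) for binary:logistic")
        return math.log(base_score / (1.0 - base_score))
    if objective == "count:poisson":
        # Log: the inverse of the exp applied at prediction time.
        if base_score <= 0.0:
            raise ValueError("base_score must be positive for count:poisson")
        return math.log(base_score)
    raise NotImplementedError(f"objective not covered by this sketch: {objective}")


def margin_to_prediction(objective: str, margin: float) -> float:
    """Inverse mapping, applied to the summed margin at prediction time."""
    if objective == "reg:squarederror":
        return margin
    if objective == "binary:logistic":
        return 1.0 / (1.0 + math.exp(-margin))
    if objective == "count:poisson":
        return math.exp(margin)
    raise NotImplementedError(f"objective not covered by this sketch: {objective}")
```

The same base_score of 0.5 means a margin of 0.0 under binary:logistic but 0.5 under reg:squarederror. A file that stored base_score without the objective would therefore be ambiguous.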

However, this is not the end of the story. In some cases we need to save more than just the model itself. For example, XGBoost performs checkpointing during distributed training. Or your favorite distributed computing framework may decide to copy the model from one worker to another and continue the training there. In such cases, the serialisation output must contain enough information to continue the previous training without the user providing any parameters again. We call this scenario a memory snapshot (or memory-based serialisation method) and distinguish it from normal model IO operations. Currently, memory snapshots are used in the following places:

  • Python package: when the Booster object is pickled with the built-in pickle module.
  • R package: when the xgb.Booster object is persisted with the built-in functions saveRDS or save.

Other language bindings are still a work in progress.

:::tips 🔖 Note :::
:::info The old binary format doesn’t distinguish between the model and the raw memory serialisation format; it’s a mix of everything. This is part of the reason why we want to replace it with a more robust serialisation method. The JVM package has its own memory-based serialisation methods. :::

To enable JSON format support for model IO (saving only the trees and objective), provide a filename with .json as file extension:
Python

```python
bst.save_model('model_file_name.json')
```

R

```r
xgb.save(bst, 'model_file_name.json')
```

For memory snapshots, JSON has been the default format since XGBoost 1.3.
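A model saved with the .json extension is plain JSON, so any JSON parser can read it, not only XGBoost. The sketch below pulls a few fields out of such a file using only the standard library. It assumes the key layout from the JSON Schema section of this tutorial ("version", plus "learner" with "gradient_booster" and "objective", each carrying a "name"). The helper names summarize_model and load_model_summary are hypothetical.

```python
import json
from typing import Any, Dict


def summarize_model(model: Dict[str, Any]) -> Dict[str, Any]:
    """Pick a few fields out of a parsed JSON model (hypothetical helper)."""
    learner = model["learner"]
    return {
        # "version" is a [major, minor, patch] triple.
        "version": tuple(model["version"]),
        # "gbtree", "gblinear" or "dart", per the schema's "gradient_booster".
        "booster": learner["gradient_booster"]["name"],
        "objective": learner["objective"]["name"],
    }


def load_model_summary(path: str) -> Dict[str, Any]:
    """Read a file written by save_model('....json') using only the json module."""
    with open(path, encoding="utf-8") as f:
        return summarize_model(json.load(f))
```

For example, load_model_summary('model_file_name.json') would read the file saved above without importing XGBoost at all.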

1、A note on backward compatibility of models and memory snapshots

We guarantee backward compatibility for models but not for memory snapshots.

Models (trees and objective) use a stable representation, so that models produced in earlier versions of XGBoost are accessible in later versions of XGBoost. If you’d like to store or archive your model for long-term storage, use save_model (Python) and xgb.save (R).

On the other hand, a memory snapshot (serialisation) captures many things internal to XGBoost, and its format is not stable and is subject to frequent changes. Memory snapshots are therefore suitable for checkpointing only. There you persist the complete snapshot of the training configuration so that you can recover robustly from possible failures and resume the training process. Loading a memory snapshot generated by an earlier version of XGBoost may result in errors or undefined behavior. If a model is persisted with pickle.dump (Python) or saveRDS (R), it may not be accessible in later versions of XGBoost.
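The compatibility rules above can be restated as a small decision helper. This is a hypothetical sketch that encodes only the policy described in this section: older-to-newer loading is guaranteed for models, while snapshots are safe only within the exact same version. It is not an XGBoost API. Versions are assumed to be (major, minor, patch) tuples, like the "version" field of a saved model.

```python
from typing import Tuple

Version = Tuple[int, int, int]


def is_safe_to_load(kind: str, produced_by: Version, running: Version) -> bool:
    """Return True if loading is covered by the compatibility guarantee.

    kind: "model" for save_model / xgb.save output,
          "snapshot" for pickle / saveRDS output (hypothetical labels).
    """
    if kind == "model":
        # Backward compatible: a newer XGBoost can read models from older ones.
        return produced_by <= running
    if kind == "snapshot":
        # No cross-version guarantee: only the same version is known to be safe.
        return produced_by == running
    raise ValueError(f"unknown kind: {kind!r}")
```

Tuple comparison is lexicographic in Python, so (1, 0, 0) <= (1, 3, 0) orders versions the way you would expect.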

2、Custom objective and metric

XGBoost accepts user-provided objective and metric functions as an extension. These functions are not saved in the model file because they are language-dependent features. With Python, users can pickle the model to include these functions in the saved binary. One drawback is that pickle output is not a stable serialization format. It doesn’t work across different Python or XGBoost versions, let alone different language environments. Another way to work around this limitation is to provide these functions again after the model is loaded. If a customized function is useful, please consider making a PR to implement it inside XGBoost, so that it works with all the language bindings.
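A custom objective is simply a function that returns per-sample first- and second-order gradients. That is why it lives in your code rather than in the model file. The sketch below computes them for 0.5 * (pred - label)^2, namely (pred - label) and 1. It assumes the (predt, dtrain) calling convention used by the obj argument of xgboost.train, where dtrain provides get_label(). To stay dependency-free it returns plain lists; with a real XGBoost installation you may need to return NumPy arrays instead. The function name squared_error_obj is hypothetical.

```python
from typing import List, Protocol, Sequence, Tuple


class HasLabels(Protocol):
    """Anything with get_label(), e.g. an xgboost.DMatrix."""

    def get_label(self) -> Sequence[float]: ...


def squared_error_obj(
    predt: Sequence[float], dtrain: HasLabels
) -> Tuple[List[float], List[float]]:
    """Gradient and hessian of 0.5 * (predt - label) ** 2 for each sample."""
    labels = dtrain.get_label()
    grad = [p - y for p, y in zip(predt, labels)]
    # The second derivative of 0.5 * (p - y) ** 2 with respect to p is 1.
    hess = [1.0] * len(grad)
    return grad, hess
```

Because nothing about this function is written into the saved model, you would pass it again when resuming from a loaded booster. For example: xgb.train(params, dtrain, obj=squared_error_obj, xgb_model=booster).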

3、Loading pickled file from different version of XGBoost

As noted, a pickled model is neither portable nor stable, but in some cases pickled models are valuable. One way to restore such a model in the future is to load it back with that specific version of Python and XGBoost, then export it by calling save_model. To ease the migration, we created a simple script that converts a pickled XGBoost 0.90 Scikit-Learn interface object to an XGBoost 1.0.0 native model. Please note that the script suits simple use cases, and we advise against using pickle when stability is needed. The script is located in xgboost/doc/python under the name convert_090to100.py. See the comments in the script for more details.

A similar procedure may be used to recover the model persisted in an old RDS file. In R, you are able to install an older version of XGBoost using the remotes package:

```r
library(remotes)
remotes::install_version("xgboost", "0.90.0.1")  # Install version 0.90.0.1
```

Once the desired version is installed, you can load the RDS file with readRDS and recover the xgb.Booster object. Then call xgb.save to export the model using the stable representation. You should now be able to use the model in the latest version of XGBoost.

4、Saving and Loading the internal parameters configuration

XGBoost’s C API, Python API and R API support saving and loading the internal configuration directly as a JSON string. In the Python package:

```python
import xgboost

bst = xgboost.train(...)
config = bst.save_config()  # a JSON string
print(config)
```

or in R:

```r
config <- xgb.config(bst)
print(config)
```

This will print something similar to the following (not the actual output, which is too long to show in full):

```json
{
  "Learner": {
    "generic_parameter": {
      "gpu_id": "0",
      "gpu_page_size": "0",
      "n_jobs": "0",
      "random_state": "0",
      "seed": "0",
      "seed_per_iteration": "0"
    },
    "gradient_booster": {
      "gbtree_train_param": {
        "num_parallel_tree": "1",
        "predictor": "gpu_predictor",
        "process_type": "default",
        "tree_method": "gpu_hist",
        "updater": "grow_gpu_hist",
        "updater_seq": "grow_gpu_hist"
      },
      "name": "gbtree",
      "updater": {
        "grow_gpu_hist": {
          "gpu_hist_train_param": {
            "debug_synchronize": "0",
            "gpu_batch_nrows": "0",
            "single_precision_histogram": "0"
          },
          "train_param": {
            "alpha": "0",
            "cache_opt": "1",
            "colsample_bylevel": "1",
            "colsample_bynode": "1",
            "colsample_bytree": "1",
            "default_direction": "learn",
            ...
            "subsample": "1"
          }
        }
      }
    },
    "learner_train_param": {
      "booster": "gbtree",
      "disable_default_eval_metric": "0",
      "dsplit": "auto",
      "objective": "reg:squarederror"
    },
    "metrics": [],
    "objective": {
      "name": "reg:squarederror",
      "reg_loss_param": {
        "scale_pos_weight": "1"
      }
    }
  },
  "version": [1, 0, 0]
}
```

You can load it back into a model generated by the same version of XGBoost with:

```python
bst.load_config(config)
```

This way users can study the internal representation more closely. Please note that some JSON generators use locale-dependent floating-point serialization methods, which XGBoost does not support.
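Since save_config returns a plain JSON string, you can inspect it with any JSON library before handing it back to load_config. The sketch below reads the version, booster and objective names. It assumes the top-level key is the lowercase "learner", as in the model schema; the abbreviated output above is not exact. Parameter values in the config are stored as strings (for example "1"). The helper name summarize_config is hypothetical. If you generate or edit such a string yourself, Python's json module is a safe choice: it always writes floats with a '.' decimal separator, independent of the process locale.

```python
import json
from typing import Any, Dict


def summarize_config(config: str) -> Dict[str, Any]:
    """Extract a few fields from a save_config()-style JSON string.

    Hypothetical helper; assumes the top-level keys "learner" and "version".
    """
    doc = json.loads(config)
    learner = doc["learner"]
    return {
        "version": tuple(doc["version"]),
        # The booster name is recorded among the learner's training parameters.
        "booster": learner["learner_train_param"]["booster"],
        "objective": learner["objective"]["name"],
    }
```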

5、Difference between saving model and dumping model

The Booster object has a function called dump_model, which lets you export the model in a readable format such as text, JSON or dot (graphviz). Its primary use case is model interpretation or visualization, and the output is not meant to be loaded back into XGBoost. The JSON version has a schema. See the next section for more info.
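As an example of the interpretation use case, the sketch below walks one tree from a JSON dump and lists its decision paths as readable rules. It assumes the node layout that a JSON dump typically has: internal nodes carry "split", "split_condition", "yes", "no", "missing" and "children", and leaves carry "leaf". It also assumes samples with a feature value below split_condition follow "yes". Treat these field names as assumptions to check against your own dump. The function name tree_rules is hypothetical.

```python
from typing import Any, Dict, List, Optional


def tree_rules(node: Dict[str, Any], path: Optional[List[str]] = None) -> List[str]:
    """Return one 'condition AND ... -> leaf=value' string per leaf.

    Hypothetical helper for one tree of a JSON dump.
    Missing-value routing ("missing") is ignored in this sketch.
    """
    path = path or []
    if "leaf" in node:
        conditions = " AND ".join(path) if path else "(root)"
        return [f"{conditions} -> leaf={node['leaf']}"]
    feature, threshold = node["split"], node["split_condition"]
    # Look children up by node id instead of relying on their list order.
    children = {child["nodeid"]: child for child in node["children"]}
    return (
        tree_rules(children[node["yes"]], path + [f"{feature}<{threshold}"])
        + tree_rules(children[node["no"]], path + [f"{feature}>={threshold}"])
    )
```

In Python, Booster.get_dump(dump_format='json') returns one JSON string per tree. Something like [tree_rules(json.loads(s)) for s in bst.get_dump(dump_format='json')] would then print the whole ensemble as rules.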

6、JSON Schema

Another important feature of the JSON format is a documented schema, which makes it easy to reuse the model output from XGBoost. Here is the initial draft of the JSON schema for the output model. It does not cover serialization, which will not be stable as noted above. The schema is subject to change due to its beta status. For an example of parsing an XGBoost tree model, see /demo/json-model. Please note the “weight_drop” field used by the “dart” booster. XGBoost does not scale tree leaves directly; instead, it saves the weights as a separate array.

```json
{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "definitions": {
    "gbtree": {
      "type": "object",
      "properties": {
        "name": {
          "const": "gbtree"
        },
        "model": {
          "type": "object",
          "properties": {
            "gbtree_model_param": {
              "$ref": "#/definitions/gbtree_model_param"
            },
            "trees": {
              "type": "array",
              "items": {
                "type": "object",
                "properties": {
                  "tree_param": {
                    "type": "object",
                    "properties": {
                      "num_nodes": {
                        "type": "string"
                      },
                      "size_leaf_vector": {
                        "type": "string"
                      },
                      "num_feature": {
                        "type": "string"
                      }
                    },
                    "required": [
                      "num_nodes",
                      "num_feature",
                      "size_leaf_vector"
                    ]
                  },
                  "id": {
                    "type": "integer"
                  },
                  "loss_changes": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "sum_hessian": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "base_weights": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "left_children": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "right_children": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "parents": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "split_indices": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "split_conditions": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "split_type": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "default_left": {
                    "type": "array",
                    "items": {
                      "type": "boolean"
                    }
                  },
                  "categories": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "categories_nodes": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "categories_segments": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "categories_sizes": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  }
                },
                "required": [
                  "tree_param",
                  "loss_changes",
                  "sum_hessian",
                  "base_weights",
                  "left_children",
                  "right_children",
                  "parents",
                  "split_indices",
                  "split_conditions",
                  "default_left",
                  "categories",
                  "categories_nodes",
                  "categories_segments",
                  "categories_sizes"
                ]
              }
            },
            "tree_info": {
              "type": "array",
              "items": {
                "type": "integer"
              }
            }
          },
          "required": [
            "gbtree_model_param",
            "trees",
            "tree_info"
          ]
        }
      },
      "required": [
        "name",
        "model"
      ]
    },
    "gbtree_model_param": {
      "type": "object",
      "properties": {
        "num_trees": {
          "type": "string"
        },
        "size_leaf_vector": {
          "type": "string"
        }
      },
      "required": [
        "num_trees",
        "size_leaf_vector"
      ]
    },
    "tree_param": {
      "type": "object",
      "properties": {
        "num_nodes": {
          "type": "string"
        },
        "size_leaf_vector": {
          "type": "string"
        },
        "num_feature": {
          "type": "string"
        }
      },
      "required": [
        "num_nodes",
        "num_feature",
        "size_leaf_vector"
      ]
    },
    "reg_loss_param": {
      "type": "object",
      "properties": {
        "scale_pos_weight": {
          "type": "string"
        }
      }
    },
    "aft_loss_param": {
      "type": "object",
      "properties": {
        "aft_loss_distribution": {
          "type": "string"
        },
        "aft_loss_distribution_scale": {
          "type": "string"
        }
      }
    },
    "softmax_multiclass_param": {
      "type": "object",
      "properties": {
        "num_class": { "type": "string" }
      }
    },
    "lambda_rank_param": {
      "type": "object",
      "properties": {
        "num_pairsample": { "type": "string" },
        "fix_list_weight": { "type": "string" }
      }
    }
  },
  "type": "object",
  "properties": {
    "version": {
      "type": "array",
      "items": [
        {
          "type": "number",
          "const": 1
        },
        {
          "type": "number",
          "minimum": 0
        },
        {
          "type": "number",
          "minimum": 0
        }
      ],
      "minItems": 3,
      "maxItems": 3
    },
    "learner": {
      "type": "object",
      "properties": {
        "feature_names": {
          "type": "array",
          "items": {
            "type": "string"
          }
        },
        "feature_types": {
          "type": "array",
          "items": {
            "type": "string"
          }
        },
        "gradient_booster": {
          "oneOf": [
            {
              "$ref": "#/definitions/gbtree"
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "gblinear" },
                "model": {
                  "type": "object",
                  "properties": {
                    "weights": {
                      "type": "array",
                      "items": {
                        "type": "number"
                      }
                    }
                  }
                }
              }
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "dart" },
                "gbtree": {
                  "$ref": "#/definitions/gbtree"
                },
                "weight_drop": {
                  "type": "array",
                  "items": {
                    "type": "number"
                  }
                }
              },
              "required": [
                "name",
                "gbtree",
                "weight_drop"
              ]
            }
          ]
        },
        "objective": {
          "oneOf": [
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:squarederror" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param" }
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:pseudohubererror" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param" }
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:squaredlogerror" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param" }
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:linear" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param" }
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:logistic" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param" }
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "binary:logistic" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param" }
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "binary:logitraw" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param" }
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "count:poisson" },
                "poisson_regression_param": {
                  "type": "object",
                  "properties": {
                    "max_delta_step": { "type": "string" }
                  }
                }
              },
              "required": [
                "name",
                "poisson_regression_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:tweedie" },
                "tweedie_regression_param": {
                  "type": "object",
                  "properties": {
                    "tweedie_variance_power": { "type": "string" }
                  }
                }
              },
              "required": [
                "name",
                "tweedie_regression_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "survival:cox" }
              },
              "required": [ "name" ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:gamma" }
              },
              "required": [ "name" ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "multi:softprob" },
                "softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param" }
              },
              "required": [
                "name",
                "softmax_multiclass_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "multi:softmax" },
                "softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param" }
              },
              "required": [
                "name",
                "softmax_multiclass_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "rank:pairwise" },
                "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param" }
              },
              "required": [
                "name",
                "lambda_rank_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "rank:ndcg" },
                "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param" }
              },
              "required": [
                "name",
                "lambda_rank_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "rank:map" },
                "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param" }
              },
              "required": [
                "name",
                "lambda_rank_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "survival:aft" },
                "aft_loss_param": { "$ref": "#/definitions/aft_loss_param" }
              }
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "binary:hinge" }
              }
            }
          ]
        },
        "learner_model_param": {
          "type": "object",
          "properties": {
            "base_score": { "type": "string" },
            "num_class": { "type": "string" },
            "num_feature": { "type": "string" }
          }
        }
      },
      "required": [
        "gradient_booster",
        "objective"
      ]
    }
  },
  "required": [
    "version",
    "learner"
  ]
}
```
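The sketch below illustrates how the tree arrays in the schema fit together. It evaluates trees from a JSON model on a single dense sample and applies the dart scaling by "weight_drop". It rests on several assumptions, in the spirit of the /demo/json-model parser:

- A node is a leaf when its left_children entry is -1.
- For leaves, split_conditions holds the leaf value.
- A sample goes left when its feature value is less than the split condition.
- A sample follows default_left when the value is missing (None or NaN here).

Categorical splits and multi-output models (tree_info groups) are not handled. base_margin stands for base_score already transformed to margin space. The function names are hypothetical.

```python
import math
from typing import Any, Dict, List, Optional, Sequence


def predict_tree(tree: Dict[str, Any], x: Sequence[Optional[float]]) -> float:
    """Return the leaf value reached by sample x in one tree of a JSON model.

    Assumes numerical splits only; None or NaN in x is treated as missing.
    """
    left = tree["left_children"]
    right = tree["right_children"]
    nid = 0  # node 0 is the root
    while left[nid] != -1:  # -1 marks a leaf
        value = x[tree["split_indices"][nid]]
        if value is None or math.isnan(value):
            # Missing value: follow the learned default direction.
            nid = left[nid] if tree["default_left"][nid] else right[nid]
        elif value < tree["split_conditions"][nid]:
            nid = left[nid]
        else:
            nid = right[nid]
    # For leaf nodes, the leaf value is stored in split_conditions.
    return tree["split_conditions"][nid]


def predict_margin(
    trees: List[Dict[str, Any]],
    x: Sequence[Optional[float]],
    base_margin: float = 0.0,
    weight_drop: Optional[List[float]] = None,
) -> float:
    """Sum tree outputs for a single-output model; dart trees are scaled."""
    total = base_margin
    for i, tree in enumerate(trees):
        # dart keeps leaves unscaled and stores one weight per tree instead.
        scale = 1.0 if weight_drop is None else weight_drop[i]
        total += scale * predict_tree(tree, x)
    return total
```

For a gbtree model, trees would come from learner["gradient_booster"]["model"]["trees"]. For dart, they would come from learner["gradient_booster"]["gbtree"]["model"]["trees"], with weight_drop from learner["gradient_booster"]["weight_drop"].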

7、Future Plans

Right now, using the JSON format incurs a longer serialisation time. We have been working on optimizing the JSON implementation to close the gap between the binary and JSON formats.