The DB-GPT-Hub project publishes a pip package to lower the barrier to Text2SQL fine-tuning. Besides fine-tuning with the scripts provided in the repository, you can also fine-tune through the Python package we provide.

Installation

  pip install dbgpt_hub
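
A quick sanity check (a sketch, not part of the official docs): after installation, confirm that the package can be imported before moving on to the steps below.

  # Minimal import check; raises ImportError if the installation failed.
  import dbgpt_hub

  print("dbgpt_hub loaded from:", dbgpt_hub.__file__)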

View the baseline

  from dbgpt_hub.baseline import show_scores
  show_scores()

(Figure 1: Using the dbgpt_hub package)
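
If you want to keep a copy of the baseline table, the sketch below captures it with standard-library tools, assuming show_scores() writes the table to standard output (which is how the example above displays it).

  import contextlib
  import io

  from dbgpt_hub.baseline import show_scores

  # Redirect stdout so the printed baseline table can be saved to a file.
  buffer = io.StringIO()
  with contextlib.redirect_stdout(buffer):
      show_scores()

  with open("baseline_scores.txt", "w", encoding="utf-8") as f:
      f.write(buffer.getvalue())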

Fine-tuning

The workflow has four stages, each driven by a dictionary of arguments: data preprocessing, training, prediction, and evaluation.

  from dbgpt_hub.data_process import preprocess_sft_data
  from dbgpt_hub.train import start_sft
  from dbgpt_hub.predict import start_predict
  from dbgpt_hub.eval import start_evaluate

  # Dataset configuration: where the raw Spider files live and where the
  # preprocessed train/dev splits are written.
  data_folder = "dbgpt_hub/data"
  data_info = [
      {
          "data_source": "spider",
          "train_file": ["train_spider.json", "train_others.json"],
          "dev_file": ["dev.json"],
          "tables_file": "tables.json",
          "db_id_name": "db_id",
          "is_multiple_turn": False,
          "train_output": "spider_train.json",
          "dev_output": "spider_dev.json",
      }
  ]

  # LoRA fine-tuning arguments for CodeLlama-13B-Instruct.
  train_args = {
      "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
      "do_train": True,
      "dataset": "example_text2sql_train",
      "max_source_length": 2048,
      "max_target_length": 512,
      "finetuning_type": "lora",
      "lora_target": "q_proj,v_proj",
      "template": "llama2",
      "lora_rank": 64,
      "lora_alpha": 32,
      "output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
      "overwrite_cache": True,
      "overwrite_output_dir": True,
      "per_device_train_batch_size": 1,
      "gradient_accumulation_steps": 16,
      "lr_scheduler_type": "cosine_with_restarts",
      "logging_steps": 50,
      "save_steps": 2000,
      "learning_rate": 2e-4,
      "num_train_epochs": 8,
      "plot_loss": True,
      "bf16": True,
  }

  # Prediction arguments: generate SQL for the dev set with the trained adapter.
  predict_args = {
      "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
      "template": "llama2",
      "finetuning_type": "lora",
      "checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
      "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
      "predict_out_dir": "dbgpt_hub/output/",
      "predicted_out_filename": "pred_sql.sql",
  }

  # Evaluation arguments: compare the predicted SQL against the gold queries.
  evaluate_args = {
      "input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
      "gold": "./dbgpt_hub/data/eval_data/gold.txt",
      "gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
      "db": "./dbgpt_hub/data/spider/database",
      "table": "./dbgpt_hub/data/eval_data/tables.json",
      "table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
      "etype": "exec",
      "plug_value": True,
      "keep_distict": False,
      "progress_bar_for_each_datapoint": False,
      "natsql": False,
  }

  # Run the full pipeline: preprocess, fine-tune, predict, evaluate.
  preprocess_sft_data(
      data_folder=data_folder,
      data_info=data_info,
  )
  start_sft(train_args)
  start_predict(predict_args)
  start_evaluate(evaluate_args)
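
The four calls at the end run sequentially, and each later stage reads only the files written by the earlier ones (the adapter checkpoint, then the prediction file), so later stages can be re-run without repeating training as long as the earlier outputs are already on disk. Below is a minimal sketch, assuming the definitions above live in the same script, that wraps the pipeline in a __main__ guard and times each stage; it is a convenience wrapper, not part of the dbgpt_hub API.

  import time

  if __name__ == "__main__":
      stages = [
          ("preprocess", lambda: preprocess_sft_data(data_folder=data_folder, data_info=data_info)),
          ("train", lambda: start_sft(train_args)),
          ("predict", lambda: start_predict(predict_args)),
          ("evaluate", lambda: start_evaluate(evaluate_args)),
      ]
      for name, run in stages:
          started = time.time()
          run()  # comment out entries in `stages` above to re-run only part of the pipeline
          print(f"[{name}] finished in {time.time() - started:.1f}s")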