Text2SQL相关微调的代码,我们拆分到了DB-GPT-Hub子项目当中,也可以直接查看源码

微调流程

Text2SQL微调主要包含以下流程:

环境搭建

我们推荐使用conda虚拟环境 来搭建Text2SQL微调的环境

  1. git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
  2. cd DB-GPT-Hub
  3. conda create -n dbgpt_hub python=3.10
  4. conda activate dbgpt_hub
  5. conda install -c conda-forge poetry>=1.4.0
  6. poetry install

当前项目支持多种基座模型,可以按需下载。本教程中我们以<font style="color:rgba(0, 0, 0, 0.9);">CodeLlama-13b-Instruct-hf</font>为基座模型,模型可以从HuggingFace魔搭 等平台下载。以HuggingFace为例, 下载命令为:

  1. cd Your_model_dir
  2. git lfs install
  3. git clone git@hf.co:codellama/CodeLlama-13b-Instruct-hf

数据处理

数据收集

本教程案例数据主要以 Spider 数据集为示例 :

  • 简介:Spider 数据集是一个跨域的复杂 text2sql 数据集,包含了自然语言问句和分布在 200 个独立数据库中的多条 SQL,内容覆盖了 138 个不同的领域。
  • 下载:下载数据集到项目目录, 即位于<font style="color:rgba(0, 0, 0, 0.9);">dbgpt_hub/data/spider</font>中。

数据处理

项目使用的是信息匹配生成法进行数据准备,即结合表信息的 SQL + Repository 生成方式,这种方式结合了数据表信息,能够更好地理解数据表的结构和关系,适用于生成符合需求的 SQL 语句。 项目已经将相关处理代码封装在对应脚本中,可以直接一键运行脚本命令,在 <font style="color:rgba(0, 0, 0, 0.9);">dbgpt_hub/data/</font>目录中将得到生成的训练集 <font style="color:rgba(0, 0, 0, 0.9);">example_text2sql_train.json</font><font style="color:rgba(0, 0, 0, 0.9);">example_text2sql_dev.json</font> python ## 生成 train 数据 和 dev(eval) 数据, sh dbgpt_hub/scripts/gen_train_eval_data.sh 其中训练集中为 8659 条,评估集为 1034 条。生成的训练集中数据格式形如:
  1. {
  2. "db_id": "department_management",
  3. "instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n\n",
  4. "input": "###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:",
  5. "output": "SELECT count(*) FROM head WHERE age > 56",
  6. "history": []
  7. }
<font style="color:rgba(0, 0, 0, 0.9);">dbgpt_hub/data/dataset_info.json</font> 中配置训练的数据文件,json文件中对应的 key 的值默认为 <font style="color:rgba(0, 0, 0, 0.9);">example_text2sql</font>,此值即在后续训练脚本 <font style="color:rgba(0, 0, 0, 0.9);">train_sft</font> 中参数<font style="color:rgba(0, 0, 0, 0.9);"> --dataset</font> 需要传入的值, json中的file_name 的值为训练集的文件名字。

数据处理代码逻辑

数据处理的核心代码主要在 <font style="color:rgba(0, 0, 0, 0.9);">dbgpt_hub/data_process/sql_data_process.py</font> 中,核心处理 class 是 <font style="color:rgba(0, 0, 0, 0.9);">ProcessSqlData()</font>,核心处理函数是 <font style="color:rgba(0, 0, 0, 0.9);">decode_json_file()</font>

<font style="color:rgba(0, 0, 0, 0.9);">decode_json_file() </font>首先将 Spider 数据中的 table 信息处理成为字典格式,key 和 value 分别是 db_id 和该 db_id 对应的 table、columns 信息处理成所需的格式,例如:

  1. {
  2. "department_management": department_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.
  3. }
然后将上述文本填充于 config 文件中 INSTRUCTION_PROMPT 的 {} 部分,形成最终的 instruction, INSTRUCTION_PROMPT 如下所示:
  1. INSTRUCTION_PROMPT = "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n ##Instruction:\n{}\n"
最后将训练集和验证集中每一个 db_id 对应的 question 和 query 处理成模型 SFT 训练所需的格式,即上面数据处理代码执行部分所示的数据格式。

:::success 说明: 如果你想自己收集更多数据进行训练,可以利用本项目相关代码参照如上逻辑进行处理。

:::

SFT训练

为了简便起见,本复现教程以 LoRA 微调直接运行作为示例,但项目微调不仅能支持 LoRA 还支持 QLoRA 以及 deepspeed 加速。训练脚本 <font style="color:rgba(0, 0, 0, 0.9);">dbgpt_hub/scripts/train_sft.sh</font> 详细参数如下所示:
  1. CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
  2. --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
  3. --do_train \
  4. --dataset example_text2sql_train \
  5. --max_source_length 2048 \
  6. --max_target_length 512 \
  7. --finetuning_type lora \
  8. --lora_target q_proj,v_proj \
  9. --template llama2 \
  10. --lora_rank 64 \
  11. --lora_alpha 32 \
  12. --output_dir dbgpt_hub/output/adapter/code_llama-13b-2048_epoch8_lora \
  13. --overwrite_cache \
  14. --overwrite_output_dir \
  15. --per_device_train_batch_size 1 \
  16. --gradient_accumulation_steps 16 \
  17. --lr_scheduler_type cosine_with_restarts \
  18. --logging_steps 50 \
  19. --save_steps 2000 \
  20. --learning_rate 2e-4 \
  21. --num_train_epochs 8 \
  22. --plot_loss \
  23. --bf16

:::color1 train_sft.sh 中关键参数与含义介绍:

  • model_name_or_path :所用 LLM 模型的路径。
  • dataset :取值为训练数据集的配置名字,对应在 dbgpt_hub/data/dataset_info.json 中外层 key 值,如 example_text2sql。
  • max_source_length :输入模型的文本长度,本教程的效果参数为 2048,为多次实验与分析后的最佳长度。
  • max_target_length :输出模型的 sql 内容长度,设置为 512。
  • template:项目设置的不同模型微调的 lora 部分,对于 Llama2 系列的模型均设置为 llama2。
  • lora_target :LoRA 微调时的网络参数更改部分。
  • finetuning_type : 微调类型,取值为 [ ptuning、lora、freeze、full ] 等。
  • lora_rank : LoRA 微调中的秩大小。
  • loran_alpha: LoRA 微调中的缩放系数。
  • output_dir :SFT 微调时 Peft 模块输出的路径,默认设置在 dbgpt_hub/output/adapter/路径下 。
  • per_device_train_batch_size :每张 gpu 上训练样本的批次,如果计算资源支持,可以设置为更大,默认为 1。
  • gradient_accumulation_steps :梯度更新的累计steps值。
  • lr_scheduler_type :学习率类型。
  • logging_steps :日志保存的 steps 间隔。
  • save_steps :模型保存的 ckpt 的 steps 大小值。
  • num_train_epochs :训练数据的 epoch 数。
  • learning_rate : 学习率,推荐的学习率为 2e-4。

:::

如果想基于 QLoRA 训练,可以在脚本中增加参数 quantization_bit 表示是否量化,取值为 [ 4 或者 8 ],开启量化。 对于想微调不同的 LLM,不同模型对应的关键参数 lora_target 和 template,可以参照项目的 README.md 中相关内容进行更改。

合并权重

模型预测

模型训练结束后,对训练好的模型进行预测,可以直接运行项目脚本目录中的predict_sft.sh
预测运行命令:

  1. sh ./dbgpt_hub/scripts/predict_sft.sh
项目目录下<font style="color:rgba(0, 0, 0, 0.9);">./dbgpt_hub/ </font>下的<font style="color:rgba(0, 0, 0, 0.9);"> output/pred/</font>,此文件路径为关于模型预测结果默认输出的位置(如果没有则需建立)。本教程 <font style="color:rgba(0, 0, 0, 0.9);">predict_sft.sh</font> 中的详细参数如下
  1. echo " predict Start time: $(date)"
  2. ## predict
  3. CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/predict/predict.py \
  4. --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
  5. --template llama2 \
  6. --finetuning_type lora \
  7. --checkpoint_dir Your_last_peft_checkpoint-4000 \
  8. --predicted_out_filename Your_model_pred.sql
  9. echo "predict End time: $(date)"
其中参数<font style="color:rgba(0, 0, 0, 0.9);"> </font>**<font style="color:rgba(0, 0, 0, 0.9);">--predicted_out_filename</font>** 的值为模型预测的结果文件名,结果在 <font style="color:rgba(0, 0, 0, 0.9);">dbgpt_hub/output/pred</font> 目录下可以找到。

效果评估

对于模型在数据集上的效果评估,默认为在 spider 数据集上。运行以下命令:
  1. python dbgpt_hub/eval/evaluation.py --plug_value --input Your_model_pred.sql

由于大模型生成的结果具有一定的随机性,和 temperature 等参数密切相关(可以在 /dbgpt_hub/configs/model_args.py 中的 GeneratingArguments 中进行调整)。在项目默认情况下,我们多次评估的效果,执行准确率均在 0.789 及以上。部分实验和评测结果我们已经放在项目 docs/eval_llm_result.md中,仅供参考。

DB-GPT-Hub 基于 CodeLlama-13b-Instruct-hf大模型用 LoRA 在 Spider 的训练集上微调后的权重文件已经放出,目前在 spider 的评估集上实现了约为 0.789 的执行准确率,权重文件CodeLlama-13b-sql-loraHuggingFace 上可以找到。

附录说明

本文实验环境为基于一台带有 A100(40G) 的显卡服务器,总训练时长 12h 左右。如果你的机器资源不够,可以优先考虑缩小参数 gradient_accumulation_steps 的取值,另外可以考虑用 QLoRA 的方式微调(训练脚本 <font style="color:rgba(0, 0, 0, 0.9);">dbgpt_hub/scripts/train_sft.sh</font>中增加 --quantization_bit 4),从我们的经验看,QLoRA 在 8 个 epoch 时和 LoRA 微调的结果相差不大。