DAG orchestration plan

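This example uses AWEL DAG orchestration to show how a large language model (LLM) handles a natural-language question from a user. The pipeline finds the relevant database tables (schema linking) and generates the corresponding SQL. It then executes that SQL to fetch the data from the database and draws the result as a chart. The orchestration consists of the following steps: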

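  • Send an HTTP request
  • Process the request content
  • Run LLM inference to obtain the relevant schema information
  • Run LLM inference to generate the SQL statement
  • Execute the SQL and fetch the result
  • Draw the chart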

[Figure 1: Getting started with AWEL SchemaLinking]

As before, `MapOperator` and `JoinOperator` are built-in DB-GPT operators, so they can be imported and used directly.

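```python
import os
from typing import Any, Dict, Optional

from pandas import DataFrame
from pydantic import BaseModel, Field

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator

# The snippets below also use other DB-GPT names (RDBMSDatabase, VectorStoreConnector,
# SchemaLinking, run_async_tasks). Their module paths may differ between DB-GPT
# versions, so copy those imports from the full example file.
```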

Custom operators

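As before, we need a custom operator that processes the user's request and builds the model's input parameters. First, define the request body. It carries a single field, `query`, which holds the user's input: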

```python
class TriggerReqBody(BaseModel):
    query: str = Field(..., description="User query")
```

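Next, define a custom `RequestHandleOperator` that builds the model-inference parameters from the request body. It inherits from `MapOperator`, so overriding the `map` method is enough to construct the parameters.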

```python
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        # Wrap the user's question in a plain dict for the downstream operators
        params = {
            "query": input_value.query,
        }
        print(f"Receive input value: {input_value}")
        return params
```

The built-in `MapOperator` then extracts the `query` string from that dict:

```python
# Pull the query string out of the dict produced by RequestHandleOperator
query_operator = MapOperator(lambda request: request["query"])
```
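To make the `map` pattern concrete without a running DB-GPT install, here is a toy stand-in for `MapOperator` in plain Python. It is a sketch under stated assumptions, not AWEL's implementation. `ToyMapOperator`, its `call` method and the `ToyTriggerReqBody` dataclass are hypothetical names. It shows the two ways this section builds map operators. One is subclassing and overriding `map`, as `RequestHandleOperator` does. The other is passing a plain function, as `query_operator` does.

```python
import asyncio
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


class ToyMapOperator:
    """Hypothetical stand-in for a map operator: one input in, one output out."""

    def __init__(self, map_function: Optional[Callable[[Any], Any]] = None):
        self._map_function = map_function

    def map(self, input_value: Any) -> Any:
        # Default behaviour: delegate to the function given to the constructor
        if self._map_function is None:
            raise NotImplementedError("override map() or pass map_function")
        return self._map_function(input_value)

    async def call(self, input_value: Any) -> Any:
        # Accept both plain and async map() implementations
        result = self.map(input_value)
        if inspect.isawaitable(result):
            result = await result
        return result


@dataclass
class ToyTriggerReqBody:
    """Hypothetical stand-in for the pydantic TriggerReqBody."""

    query: str


class ToyRequestHandleOperator(ToyMapOperator):
    # Same idea as RequestHandleOperator: override map() to build the params dict
    async def map(self, input_value: ToyTriggerReqBody) -> Dict[str, str]:
        return {"query": input_value.query}


# Same idea as query_operator: pass a lambda instead of subclassing
toy_query_operator = ToyMapOperator(lambda request: request["query"])


async def run_chain(body: ToyTriggerReqBody) -> str:
    params = await ToyRequestHandleOperator().call(body)
    return await toy_query_operator.call(params)


print(asyncio.run(run_chain(ToyTriggerReqBody(query="How many users are there?"))))
```

Handling both plain and async `map` in one `call` method reflects how this section mixes the two. `RequestHandleOperator` and `SqlGenOperator` use `async def map`. `SqlExecOperator` and `ChartDrawOperator` use a plain `def map`.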

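With the `query` string in hand, a custom `SchemaLinkingOperator`, which again overrides `map`, finds the table information most relevant to the query. The actual logic is implemented in the `schema_linking_with_llm` method of the `SchemaLinking` class.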

```python
class SchemaLinkingOperator(MapOperator[Any, Any]):
    """The Schema Linking Operator."""

    def __init__(
        self,
        top_k: int = 5,
        connection: Optional[RDBMSDatabase] = None,
        llm: Optional[LLMClient] = None,
        model_name: Optional[str] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """Init the schema linking operator.

        Args:
            top_k (int): Number of most similar chunks to retrieve from the vector store.
            connection (RDBMSDatabase): The database connection.
            llm (Optional[LLMClient]): The base LLM client.
            model_name (Optional[str]): The model to call through ``llm``.
            vector_store_connector (Optional[VectorStoreConnector]): Optional vector store.
        """
        super().__init__(**kwargs)
        self._schema_linking = SchemaLinking(
            top_k=top_k,
            connection=connection,
            llm=llm,
            model_name=model_name,
            vector_store_connector=vector_store_connector,
        )

    async def map(self, query: str) -> str:
        """Retrieve the table schemas relevant to the query.

        Args:
            query (str): The user query.

        Returns:
            str: The schema information.
        """
        return str(await self._schema_linking.schema_linking_with_llm(query))
```

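Inside `SchemaLinking`, the work is done by `_schema_linking_with_llm`. It first calls `self.schema_linking` to fetch the structure of every table in the database. The result, `chunks_content`, is a list of chunk contents in the same format RAG uses for split documents. It then concatenates these schemas with the query into `schema_prompt`. Finally, it asks the LLM to pick out the most relevant schema information. The alternative, `_schema_linking_with_vector_db`, does not call the LLM. It retrieves the `top_k` most similar chunks from the vector store instead.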

```python
# Excerpt of methods from DB-GPT's SchemaLinking class. Chunk, _parse_db_summary,
# INSTRUCTION, INPUT_PROMPT and run_async_tasks come from DB-GPT modules not shown here.
from functools import reduce
from typing import List


def _schema_linking(self, query: str) -> List:
    """Get the schema info of every table in the database."""
    table_summaries = _parse_db_summary(self._connection)
    # Wrap each table summary as a RAG Chunk, then keep only the text content
    chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
    chunks_content = [chunk.content for chunk in chunks]
    return chunks_content


def _schema_linking_with_vector_db(self, query: str) -> List:
    queries = [query]
    # One similarity search per query, each returning the top_k closest chunks
    candidates = [
        self._vector_store_connector.similar_search(query, self._top_k)
        for query in queries
    ]
    # Flatten the per-query result lists into a single list
    candidates = reduce(lambda x, y: x + y, candidates)
    return candidates


async def _schema_linking_with_llm(self, query: str) -> str:
    # schema_linking() is the public counterpart of _schema_linking() above
    chunks_content = self.schema_linking(query)
    schema_prompt = INSTRUCTION.format(
        str(chunks_content) + INPUT_PROMPT.format(query)
    )
    messages = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content=schema_prompt)
    ]
    request = ModelRequest(model=self._model_name, messages=messages)
    tasks = [self._llm.generate(request)]
    # Get accurate schema info from the LLM
    schema = await run_async_tasks(tasks=tasks, concurrency_limit=1)
    schema_text = schema[0].text
    return schema_text
```
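At its core, the `_schema_linking` step describes every table as text. Below is a self-contained sketch using Python's built-in `sqlite3` module. `table_summaries` is a hypothetical, simplified stand-in for DB-GPT's `_parse_db_summary`. Its output format is an assumption, not what DB-GPT actually produces. `fake_llm_pick` is a hypothetical keyword filter that replaces the LLM call so the example runs offline. It is a swapped-in technique, not how the real operator selects tables.

```python
import sqlite3
from typing import List


def table_summaries(conn: sqlite3.Connection) -> List[str]:
    """Hypothetical stand-in for _parse_db_summary: one text summary per table."""
    summaries = []
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    for (table_name,) in tables:
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
        columns = conn.execute(f'PRAGMA table_info("{table_name}")').fetchall()
        column_list = ", ".join(f"{col[1]} {col[2]}" for col in columns)
        summaries.append(f"table {table_name}({column_list})")
    return summaries


def fake_llm_pick(query: str, chunks_content: List[str]) -> str:
    """Keyword filter standing in for the LLM: keep tables named in the query."""
    words = set(query.lower().split())
    picked = [c for c in chunks_content if c.split("(")[0].split()[1] in words]
    return "\n".join(picked)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)")
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER, amount REAL)")

chunks_content = table_summaries(conn)
print(chunks_content)
print(fake_llm_pick("statistics of user age in the user table", chunks_content))
```

The real operator passes every table summary to the LLM through `INSTRUCTION` and `INPUT_PROMPT`, so the model decides which tables are relevant, not a keyword match.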

With the schema information most relevant to the `query` string in hand, the built-in `JoinOperator` combines `query` and the schema into a prompt:

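```python
# INSTRUCTION and INPUT_PROMPT are prompt templates with a positional "{}" slot,
# defined elsewhere in the example (not shown here).
def _prompt_join_fn(query: str, chunks: str) -> str:
    # query comes from query_operator, chunks from the schema linking operator
    prompt = INSTRUCTION.format(chunks + INPUT_PROMPT.format(query))
    return prompt


prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
```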

With the prompt ready, a custom `SqlGenOperator` calls `llm.generate` to produce the SQL:

```python
class SqlGenOperator(MapOperator[Any, Any]):
    """The Sql Generation Operator."""

    def __init__(self, llm: Optional[LLMClient], model_name: str, **kwargs):
        """Init the sql generation operator.

        Args:
            llm (Optional[LLMClient]): The base LLM client.
            model_name (str): The model to call through ``llm``.
        """
        super().__init__(**kwargs)
        self._llm = llm
        self._model_name = model_name

    async def map(self, prompt_with_query_and_schema: str) -> str:
        """Generate SQL with the LLM.

        Args:
            prompt_with_query_and_schema (str): Prompt containing the query and schema.

        Returns:
            str: The generated SQL.
        """
        messages = [
            ModelMessage(
                role=ModelMessageRoleType.SYSTEM, content=prompt_with_query_and_schema
            )
        ]
        request = ModelRequest(model=self._model_name, messages=messages)
        tasks = [self._llm.generate(request)]
        # A single generation task, run with a concurrency limit of 1
        output = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        sql = output[0].text
        return sql
```
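Both `SqlGenOperator` and `SchemaLinking` hand their `generate` coroutine to `run_async_tasks` with `concurrency_limit=1`. The sketch below is a hypothetical re-implementation of that idea with `asyncio.Semaphore`, not DB-GPT's actual code. It runs a list of coroutines with at most `concurrency_limit` in flight and returns their results in input order. `fake_generate` is a made-up stand-in for `llm.generate` that also records how many calls ran at the same time.

```python
import asyncio
from typing import Any, Coroutine, List


async def run_async_tasks_sketch(
    tasks: List[Coroutine[Any, Any, Any]], concurrency_limit: int
) -> List[Any]:
    """Await coroutines with at most `concurrency_limit` running at once."""
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def run_one(task: Coroutine[Any, Any, Any]) -> Any:
        async with semaphore:
            return await task

    # gather() keeps results in the same order as the input tasks
    return await asyncio.gather(*(run_one(task) for task in tasks))


in_flight = 0
max_in_flight = 0


async def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for llm.generate: returns a fake SQL string."""
    global in_flight, max_in_flight
    in_flight += 1
    max_in_flight = max(max_in_flight, in_flight)
    await asyncio.sleep(0.01)  # pretend to wait on the model
    in_flight -= 1
    return f"SELECT 1 -- for: {prompt}"


results = asyncio.run(
    run_async_tasks_sketch([fake_generate("a"), fake_generate("b")], concurrency_limit=1)
)
print(results, max_in_flight)
```

This section only ever submits one task, so the limit changes nothing here. It matters only when several generations are batched together.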

With the SQL generated, a custom `SqlExecOperator` executes it through `run_to_df` and returns the result from the database as a `DataFrame`.

```python
class SqlExecOperator(MapOperator[Any, Any]):
    """The Sql Execution Operator."""

    def __init__(self, connection: Optional[RDBMSDatabase] = None, **kwargs):
        """Init the sql execution operator.

        Args:
            connection (Optional[RDBMSDatabase]): RDBMSDatabase connection
        """
        super().__init__(**kwargs)
        self._connection = connection

    def map(self, sql: str) -> DataFrame:
        """Execute the SQL against the database.

        Args:
            sql (str): The SQL statement to run.

        Returns:
            DataFrame: The query result.
        """
        # fetch="all" returns every row of the result set
        dataframe = self._connection.run_to_df(command=sql, fetch="all")
        print(f"sql data is \n{dataframe}")
        return dataframe
```
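`run_to_df` is a method of the DB-GPT connection object. To try the same step without DB-GPT, the sketch below uses the standard-library `sqlite3` module as a stand-in. It runs a SQL string and returns the column names plus all rows, which is roughly the information a `DataFrame` would carry. `run_sql` is a hypothetical helper. The `user` table, its sample rows and the bucketing SQL are made up for illustration, modelled on the age-statistics question in the curl test (under 10, 10 to 20 inclusive, over 20).

```python
import sqlite3
from typing import Any, List, Tuple


def run_sql(
    conn: sqlite3.Connection, sql: str
) -> Tuple[List[str], List[Tuple[Any, ...]]]:
    """Hypothetical stand-in for run_to_df: return (column names, all rows)."""
    cursor = conn.execute(sql)
    columns = [description[0] for description in cursor.description]
    return columns, cursor.fetchall()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)")
conn.executemany(
    "INSERT INTO user (name, age) VALUES (?, ?)",
    [("a", 8), ("b", 10), ("c", 15), ("d", 20), ("e", 25), ("f", 40)],
)

# The kind of SQL the LLM might generate for the age-statistics question
sql = """
SELECT CASE
         WHEN age < 10 THEN '<10'
         WHEN age <= 20 THEN '10-20'
         ELSE '>20'
       END AS age_group,
       COUNT(*) AS count
FROM user
GROUP BY age_group
ORDER BY MIN(age)
"""

columns, rows = run_sql(conn, sql)
print(columns)
print(rows)
```

With pandas installed, `pandas.read_sql_query(sql, conn)` would return the same query result as a `DataFrame`.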

With the query result in hand, a custom `ChartDrawOperator` draws a bar chart with `matplotlib.pyplot`.

```python
class ChartDrawOperator(MapOperator[Any, Any]):
    """The Chart Draw Operator."""

    def __init__(self, **kwargs):
        """Init the chart draw operator."""
        super().__init__(**kwargs)

    def map(self, df: DataFrame) -> str:
        """Draw the SQL result as a bar chart.

        Args:
            df (DataFrame): The SQL result; the first column holds the
                categories and the second column holds the counts.

        Returns:
            str: The DataFrame rendered as a string.
        """
        import matplotlib.pyplot as plt

        category_column = df.columns[0]
        count_column = df.columns[1]
        plt.figure(figsize=(8, 4))
        plt.bar(df[category_column], df[count_column])
        plt.xlabel(category_column)
        plt.ylabel(count_column)
        plt.show()
        return str(df)
```
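`ChartDrawOperator` assumes a two-column result, with categories in the first column and counts in the second. The sketch below uses the same column convention but renders a text bar chart in plain Python instead of `matplotlib`. This is a swapped-in technique, useful for checking the SQL result in a terminal. `text_bar_chart` is a hypothetical helper, and the sample data reuses made-up age buckets.

```python
from typing import List, Sequence, Tuple


def text_bar_chart(
    columns: Sequence[str], rows: List[Tuple[str, int]], width: int = 20
) -> str:
    """Render rows of (category, count) as a text bar chart.

    Like ChartDrawOperator, the first column is the category and the
    second column is the count.
    """
    category_column, count_column = columns[0], columns[1]
    header = f"{category_column} vs {count_column}"
    if not rows:
        return header
    max_count = max(count for _, count in rows) or 1  # avoid division by zero
    label_width = max(len(str(category)) for category, _ in rows)
    lines = [header]
    for category, count in rows:
        # Integer scaling so the largest count fills `width` characters
        bar = "#" * (count * width // max_count)
        lines.append(f"{str(category):<{label_width}} | {bar} {count}")
    return "\n".join(lines)


print(text_bar_chart(["age_group", "count"], [("<10", 1), ("10-20", 3), (">20", 2)], width=6))
```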

DAG orchestration

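With the operators written, the next step is to wire them together with an AWEL DAG. `query_operator` feeds two branches. One goes straight to `prompt_join_operator`. The other passes through `retriever_task`, the schema linking step, before reaching it. As a result, `prompt_join_operator` receives both the query and the linked schema.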

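```python
# The trigger and operator instances are created earlier in the DAG definition (not shown).
# Branch 1: the query string goes directly to the join operator
trigger >> request_handle_task >> query_operator >> prompt_join_operator
# Branch 2: the query string goes through schema linking first
(
    trigger
    >> request_handle_task
    >> query_operator
    >> retriever_task
    >> prompt_join_operator
)
# The joined prompt then flows through SQL generation, execution, and charting
prompt_join_operator >> sql_gen_operator >> sql_exec_operator >> draw_chart_operator
```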
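The `>>` wiring above builds a graph in which `query_operator` fans out and `prompt_join_operator` joins the two branches. The toy sketch below mimics that shape in plain Python so the data flow can be traced end to end. It is not AWEL's implementation. `Node`, `run_dag` and the stubs `fake_schema_linking` and `fake_sql_gen` are hypothetical. In this toy, a join node receives its inputs in the order its upstream edges were added. That is an assumption of the sketch, not a documented AWEL guarantee.

```python
from typing import Any, Callable, Dict, List


class Node:
    """Toy DAG node: `a >> b` records an edge from a to b."""

    def __init__(self, name: str, fn: Callable[..., Any]):
        self.name = name
        self.fn = fn
        self.upstream: List["Node"] = []

    def __rshift__(self, other: "Node") -> "Node":
        if self not in other.upstream:  # a repeated edge is recorded only once
            other.upstream.append(self)
        return other  # returning `other` lets a >> b >> c chain


def run_dag(sink: Node, trigger_input: Any) -> Any:
    """Evaluate `sink`, computing each upstream node only once (memoized)."""
    cache: Dict[str, Any] = {}

    def evaluate(node: Node) -> Any:
        if node.name not in cache:
            if not node.upstream:  # the trigger has no upstream nodes
                cache[node.name] = node.fn(trigger_input)
            else:
                args = [evaluate(parent) for parent in node.upstream]
                cache[node.name] = node.fn(*args)
        return cache[node.name]

    return evaluate(sink)


def fake_schema_linking(query: str) -> str:
    return "table user(id, name, age)"


def fake_sql_gen(prompt: str) -> str:
    return "SELECT COUNT(*) FROM user"


trigger = Node("trigger", lambda body: body)
request_handle_task = Node("request_handle", lambda body: {"query": body["query"]})
query_operator = Node("query", lambda request: request["query"])
retriever_task = Node("schema_linking", fake_schema_linking)
prompt_join_operator = Node("join", lambda query, schema: f"{schema}\nQ: {query}")
sql_gen_operator = Node("sql_gen", fake_sql_gen)

# Same wiring pattern as the DAG above (execution and charting omitted)
trigger >> request_handle_task >> query_operator >> prompt_join_operator
trigger >> request_handle_task >> query_operator >> retriever_task >> prompt_join_operator
prompt_join_operator >> sql_gen_operator

print([node.name for node in prompt_join_operator.upstream])
print(run_dag(sql_gen_operator, {"query": "How many users?"}))
```

The memoization in `run_dag` makes the fan-out safe. `query_operator` runs once even though two branches consume its output.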

Testing and verification

  1. Install the OpenAI dependency: `pip install "db-gpt[openai]"`
  2. Set the OpenAI environment variables: `export OPENAI_API_KEY={your_openai_key}` and `export OPENAI_API_BASE={your_openai_base}`
  3. Run the code: `python examples/awel/simple_nl_schema_sql_chart_example.py`
  4. Test with curl:

```bash
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/schema_linking' \
--header 'Content-Type: application/json' \
--data '{"query": "Statistics of user age in the user table are based on three categories: age is less than 10, age is greater than or equal to 10 and less than or equal to 20, and age is greater than 20. The first column of the statistical results is different ages, and the second column is count."}'
```

:::danger ⚠️ Note: The port in the test request must match the port the service was started on.

:::
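The same request can also be sent from Python with the standard library's `urllib.request`. In this sketch, `build_trigger_request` and `send` are hypothetical helper names. The host, port and path are copied from the curl command above. The port is a parameter because it must match the one the service was started on. The snippet only builds the request and never sends it, so it runs without a server.

```python
import json
import urllib.request


def build_trigger_request(
    query: str, host: str = "127.0.0.1", port: int = 5555
) -> urllib.request.Request:
    """Build the same POST request that the curl command sends."""
    url = f"http://{host}:{port}/api/v1/awel/trigger/examples/rag/schema_linking"
    body = json.dumps({"query": query}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request, timeout: float = 60.0) -> str:
    """Send the request and return the response body as text (needs a running server)."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read().decode("utf-8")


req = build_trigger_request("Statistics of user age in the user table")
print(req.get_method(), req.full_url)
```

Once the example is running, call `send(req)` to actually submit the request.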