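The BranchOperator is a branching operator: it decides which downstream branch to run based on the input data. For example, if you have two branches, you can choose which path runs according to the input.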

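There are two ways to use the BranchOperator.

Build a BranchOperator from a branch mapping

Pass a dictionary that maps branch functions to task names to the BranchOperator constructor.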

    from dbgpt.core.awel import DAG, BranchOperator, MapOperator

    def branch_even(x: int) -> bool:
        return x % 2 == 0

    def branch_odd(x: int) -> bool:
        return not branch_even(x)

    # Keys are branch functions, values are the names of the tasks to run.
    branch_mapping = {
        branch_even: "even_task",
        branch_odd: "odd_task"
    }

    with DAG("awel_branch_operator") as dag:
        task = BranchOperator(branches=branch_mapping)
        # task_name must match the names used in branch_mapping.
        even_task = MapOperator(
            task_name="even_task",
            map_function=lambda x: print(f"{x} is even")
        )
        odd_task = MapOperator(
            task_name="odd_task",
            map_function=lambda x: print(f"{x} is odd")
        )

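In the example above, the BranchOperator has two child tasks, even_task and odd_task, and it decides which one to run based on the input data. The branch mapping is defined by passing the BranchOperator constructor a dictionary whose keys are branch functions and whose values are task names. When the branch task runs, every branch function is evaluated: if a branch function returns True, the corresponding task is executed; otherwise it is skipped.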
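The selection rule described above can be sketched in plain Python, without AWEL. This only illustrates the semantics and is not how dbgpt implements them; `select_tasks` is a hypothetical helper introduced for this sketch.

```python
from typing import Callable, Dict, List, Tuple


def branch_even(x: int) -> bool:
    return x % 2 == 0


def branch_odd(x: int) -> bool:
    return not branch_even(x)


def select_tasks(
    branches: Dict[Callable[[int], bool], str], value: int
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: evaluate every branch function on the input.

    Returns a pair (tasks to run, tasks to skip).
    """
    to_run: List[str] = []
    to_skip: List[str] = []
    for branch_func, task_name in branches.items():
        if branch_func(value):
            to_run.append(task_name)
        else:
            to_skip.append(task_name)
    return to_run, to_skip


branch_mapping = {branch_even: "even_task", branch_odd: "odd_task"}

print(select_tasks(branch_mapping, 5))  # (['odd_task'], ['even_task'])
print(select_tasks(branch_mapping, 6))  # (['even_task'], ['odd_task'])
```

Because every branch function is evaluated independently, nothing in this rule forces exactly one branch to match; the even/odd pair is mutually exclusive only because of how the two functions are written.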

Implement a custom BranchOperator

To customize the operator, override the branches method so that it returns a dictionary mapping branch functions to task names.

    from dbgpt.core.awel import DAG, BranchOperator, MapOperator

    def branch_even(x: int) -> bool:
        return x % 2 == 0

    def branch_odd(x: int) -> bool:
        return not branch_even(x)

    class MyBranchOperator(BranchOperator[int]):
        def __init__(self, even_task_name: str, odd_task_name: str, **kwargs):
            self.even_task_name = even_task_name
            self.odd_task_name = odd_task_name
            super().__init__(**kwargs)

        # Return the mapping from branch functions to task names.
        async def branches(self):
            return {
                branch_even: self.even_task_name,
                branch_odd: self.odd_task_name
            }

    with DAG("awel_branch_operator") as dag:
        task = MyBranchOperator(even_task_name="even_task", odd_task_name="odd_task")
        even_task = MapOperator(
            task_name="even_task",
            map_function=lambda x: print(f"{x} is even")
        )
        odd_task = MapOperator(
            task_name="odd_task",
            map_function=lambda x: print(f"{x} is odd")
        )

Example

In the awel_tutorial directory, create a file named branch_operator_even_or_odd.py with the following content:

    import asyncio
    from dbgpt.core.awel import (
        DAG, BranchOperator, MapOperator, JoinOperator,
        InputOperator, SimpleCallDataInputSource,
        is_empty_data
    )

    def branch_even(x: int) -> bool:
        return x % 2 == 0

    def branch_odd(x: int) -> bool:
        return not branch_even(x)

    branch_mapping = {
        branch_even: "even_task",
        branch_odd: "odd_task"
    }

    def even_func(x: int) -> int:
        print(f"Branch even, {x} is even, multiply by 10")
        return x * 10

    def odd_func(x: int) -> int:
        print(f"Branch odd, {x} is odd, multiply by itself")
        return x * x

    def combine_function(x: int, y: int) -> int:
        print(f"Received {x} and {y}")
        # Return the first non-empty data
        return x if not is_empty_data(x) else y

    with DAG("awel_branch_operator") as dag:
        input_task = InputOperator(input_source=SimpleCallDataInputSource())
        task = BranchOperator(branches=branch_mapping)
        even_task = MapOperator(task_name="even_task", map_function=even_func)
        odd_task = MapOperator(task_name="odd_task", map_function=odd_func)
        join_task = JoinOperator(combine_function=combine_function)
        # Both branches feed the same join task.
        input_task >> task >> even_task >> join_task
        input_task >> task >> odd_task >> join_task

    print("First call, input is 5")
    assert asyncio.run(join_task.call(call_data=5)) == 25
    print("=" * 80)
    print("Second call, input is 6")
    assert asyncio.run(join_task.call(call_data=6)) == 60

Run the code above and check the program output:

    poetry run python awel_tutorial/branch_operator_even_or_odd.py
    First call, input is 5
    Branch odd, 5 is odd, multiply by itself
    Received EmptyData(SKIP_DATA) and 25
    ================================================================================
    Second call, input is 6
    Branch even, 6 is even, multiply by 10
    Received 60 and EmptyData(SKIP_DATA)

The DAG is shown below:

Branch Operator - Figure 1

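In the example above, the BranchOperator has two child tasks, even_task and odd_task, and it decides which branch to run based on the input data.

We also use the JoinOperator to combine the results of the two child tasks. If a path is skipped, the JoinOperator receives EmptyData(SKIP_DATA) as the input from that path, which can be detected with dbgpt.core.awel.is_empty_data.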
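The skip-and-join behaviour can be mimicked in plain Python to check the expected results by hand. In this sketch, `SKIP` and `is_skipped` are hypothetical stand-ins for AWEL's EmptyData(SKIP_DATA) and is_empty_data, and `run_pipeline` is an invented helper; none of them are dbgpt APIs.

```python
# Hypothetical stand-in for AWEL's EmptyData(SKIP_DATA) sentinel.
SKIP = object()


def is_skipped(value: object) -> bool:
    """Hypothetical stand-in for dbgpt.core.awel.is_empty_data."""
    return value is SKIP


def even_func(x: int) -> int:
    return x * 10


def odd_func(x: int) -> int:
    return x * x


def combine_function(x: object, y: object) -> object:
    # Return the first input that was not skipped.
    return x if not is_skipped(x) else y


def run_pipeline(value: int) -> object:
    """Run only the matching branch; the other branch yields SKIP."""
    even_result = even_func(value) if value % 2 == 0 else SKIP
    odd_result = odd_func(value) if value % 2 != 0 else SKIP
    return combine_function(even_result, odd_result)


print(run_pipeline(5))  # 25
print(run_pipeline(6))  # 60
```

Using a dedicated sentinel rather than None keeps a skipped branch distinguishable from a branch that legitimately returns None or 0, which is presumably why the tutorial checks with is_empty_data instead of a truthiness test.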