1. # coding: utf8
    2. import os
    3. import re
    4. from typing import Any, Dict, Optional
    5. from airflow.configuration import conf
    6. from airflow.exceptions import AirflowException
    7. from airflow.models import BaseOperator
    8. from custom_hooks._impala import ImpalaHook
    9. from airflow.utils import operator_helpers
    10. from airflow.utils.operator_helpers import context_to_airflow_vars
    11. class ImpalaOperator(BaseOperator):
    12. template_fields = ('hql', 'impala_conn_id')
    13. template_ext = ('.hql', '.sql')
    14. ui_color = '#f0e4ec'
    15. def __init__(
    16. self,
    17. *,
    18. hql: str = None,
    19. impala_conn_id: str = 'impala_default',
    20. db: str = 'default',
    21. hql_file: Optional[str] = None,
    22. hql_name: Optional[str] = None,
    23. impalaconf_jinja_translate: bool = False,
    24. run_as_owner: bool = False,
    25. **kwargs: Any,
    26. ) -> None:
    27. super().__init__(**kwargs)
    28. self.hql = hql
    29. self.impala_conn_id = impala_conn_id
    30. self.db = db
    31. self.hql_file = hql_file
    32. self.hql_name = hql_name
    33. self.impalaconf_jinja_translate = impalaconf_jinja_translate
    34. self.run_as = None
    35. if run_as_owner:
    36. self.run_as = self.dag.owner
    37. # assigned lazily - just for consistency we can create the attribute with a
    38. # `None` initial value, later it will be populated by the execute method.
    39. # This also makes `on_kill` implementation consistent since it assumes `self.hook`
    40. # is defined.
    41. self.hook: Optional[ImpalaHook] = None
    42. def get_hook(self) -> ImpalaHook:
    43. """Get Impala hook"""
    44. return ImpalaHook()
    45. def _parse_sqlfile(self):
    46. with open(self.hql_file, 'r', encoding="utf8") as f:
    47. if self.impalaconf_jinja_translate:
    48. file_content = re.sub(r"(\$\{([ batch_date|run_date|BATCH_DATE|RUN_DATE]*)\})",
    49. r"{{ ds }}", f.read())
    50. else:
    51. file_content = f.read()
    52. pattern = r"--\[(.*?)\](.*?)\n--\[end\].*?"
    53. sqls_dict = dict([(k, v) for k, v in re.findall(pattern, file_content, re.S) if v != ""])
    54. if self.hql_name:
    55. return sqls_dict[self.hql_name]
    56. elif len(sqls_dict.keys()) == 1:
    57. key = list(sqls_dict.keys())[0]
    58. return sqls_dict[key]
    59. else:
    60. raise AirflowException("You must specify `hql_name` when `hql_file` is defined!")
    61. def prepare_template(self) -> None:
    62. if self.impalaconf_jinja_translate:
    63. if self.hql_file is not None:
    64. self.hql = self._parse_sqlfile()
    65. else:
    66. self.hql = re.sub(r"(\$\{([ a-zA-Z0-9_|batch_date|run_date|BATCH_DATE|RUN_DATE|]*)\})",
    67. r"{{ ds }}", self.hql)
    68. else:
    69. if self.hql_file is not None:
    70. self.hql = self._parse_sqlfile()
    71. def execute(self, context: Dict[str, Any]) -> None:
    72. # self.log.info('Executing: %s', self.hql)
    73. self.hook = self.get_hook()
    74. self.hook.run_hql(self.hql)
    75. self.log.info("Executed successfully!")
    76. def dry_run(self) -> None:
    77. # Reset airflow environment variables to prevent
    78. # existing env vars from impacting behavior.
    79. self.clear_airflow_vars()
    80. self.hook = self.get_hook()
    81. self.hook.test_hql(hql=self.hql)
    82. def on_kill(self) -> None:
    83. if self.hook:
    84. self.hook.kill()
    85. def clear_airflow_vars(self) -> None:
    86. """Reset airflow environment variables to prevent existing ones from impacting behavior."""
    87. blank_env_vars = {
    88. value['env_var_format']: '' for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
    89. }
    90. os.environ.update(blank_env_vars)