# coding: utf8
"""Custom Airflow operator that executes Impala HQL statements."""
import os
import re
from typing import Any, Dict, Optional

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils import operator_helpers
from airflow.utils.operator_helpers import context_to_airflow_vars

from custom_hooks._impala import ImpalaHook

# Matches ``${batch_date}`` / ``${run_date}`` placeholders (upper or lower
# case, optional surrounding whitespace) so they can be rewritten to the
# Jinja macro ``{{ ds }}``.
# BUGFIX: the original pattern put the names inside a character *class*
# (``[ batch_date|run_date|...]*``), which matched any run of those letters
# (e.g. ``${tab}`` or even ``${}``) instead of the intended alternation.
_DATE_VAR_PATTERN = re.compile(r"\$\{\s*(?:batch_date|run_date|BATCH_DATE|RUN_DATE)\s*\}")

# Named sections inside an .hql file are delimited like::
#     --[section_name]
#     SELECT ...
#     --[end]
_SECTION_PATTERN = r"--\[(.*?)\](.*?)\n--\[end\].*?"


class ImpalaOperator(BaseOperator):
    """Executes HQL code against Impala.

    :param hql: the HQL statement(s) to execute, or a reference to a
        templated file (recognised extensions: ``.hql``/``.sql``).
    :param impala_conn_id: connection id of the Impala hook (templated).
    :param db: name of the target database.
    :param hql_file: path to a file containing one or more named HQL
        sections delimited by ``--[name]`` ... ``--[end]`` markers.
    :param hql_name: which named section of ``hql_file`` to run; required
        when the file contains more than one section.
    :param impalaconf_jinja_translate: when True, rewrite
        ``${batch_date}``/``${run_date}`` placeholders to ``{{ ds }}``.
    :param run_as_owner: run the job as the owner of the DAG.
    """

    template_fields = ('hql', 'impala_conn_id')
    template_ext = ('.hql', '.sql')
    ui_color = '#f0e4ec'

    def __init__(
        self,
        *,
        hql: Optional[str] = None,  # was annotated ``str``; default is None
        impala_conn_id: str = 'impala_default',
        db: str = 'default',
        hql_file: Optional[str] = None,
        hql_name: Optional[str] = None,
        impalaconf_jinja_translate: bool = False,
        run_as_owner: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.hql = hql
        self.impala_conn_id = impala_conn_id
        self.db = db
        self.hql_file = hql_file
        self.hql_name = hql_name
        self.impalaconf_jinja_translate = impalaconf_jinja_translate
        self.run_as: Optional[str] = self.dag.owner if run_as_owner else None
        # Assigned lazily by ``execute``; created here with a ``None`` value
        # so that ``on_kill`` can safely test it even if the task is killed
        # before execution starts.
        self.hook: Optional[ImpalaHook] = None

    def get_hook(self) -> ImpalaHook:
        """Return a fresh Impala hook."""
        return ImpalaHook()

    def _parse_sqlfile(self) -> str:
        """Read ``self.hql_file`` and return the selected HQL section.

        :raises AirflowException: when ``hql_name`` names an unknown section,
            or when it is omitted while the file holds several sections.
        """
        with open(self.hql_file, 'r', encoding="utf8") as f:
            file_content = f.read()
        if self.impalaconf_jinja_translate:
            file_content = _DATE_VAR_PATTERN.sub(r"{{ ds }}", file_content)
        # Map section name -> section body, skipping empty sections.
        sqls_dict = {
            name: body
            for name, body in re.findall(_SECTION_PATTERN, file_content, re.S)
            if body != ""
        }
        if self.hql_name:
            try:
                return sqls_dict[self.hql_name]
            except KeyError:
                # Was a bare KeyError before; surface a clear operator error.
                raise AirflowException(
                    f"Section `{self.hql_name}` not found in `{self.hql_file}`!"
                )
        if len(sqls_dict) == 1:
            # Exactly one section: no ambiguity, hql_name may be omitted.
            return next(iter(sqls_dict.values()))
        raise AirflowException("You must specify `hql_name` when `hql_file` is defined!")

    def prepare_template(self) -> None:
        """Resolve ``self.hql`` from ``hql_file`` and/or translate placeholders."""
        if self.hql_file is not None:
            # _parse_sqlfile already honours impalaconf_jinja_translate.
            self.hql = self._parse_sqlfile()
        elif self.impalaconf_jinja_translate:
            self.hql = _DATE_VAR_PATTERN.sub(r"{{ ds }}", self.hql)

    def execute(self, context: Dict[str, Any]) -> None:
        """Run the prepared HQL through the Impala hook."""
        self.hook = self.get_hook()
        self.hook.run_hql(self.hql)
        self.log.info("Executed successfully!")

    def dry_run(self) -> None:
        """Validate the HQL without executing it."""
        # Reset airflow environment variables to prevent
        # existing env vars from impacting behavior.
        self.clear_airflow_vars()
        self.hook = self.get_hook()
        self.hook.test_hql(hql=self.hql)

    def on_kill(self) -> None:
        """Cancel the running query, if execution has started."""
        if self.hook:
            self.hook.kill()

    def clear_airflow_vars(self) -> None:
        """Reset airflow environment variables to prevent existing ones from impacting behavior."""
        blank_env_vars = {
            value['env_var_format']: ''
            for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
        }
        os.environ.update(blank_env_vars)