# coding: utf8
import os
import re
from typing import Any, Dict, Optional
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from custom_hooks._impala import ImpalaHook
from airflow.utils import operator_helpers
from airflow.utils.operator_helpers import context_to_airflow_vars
class ImpalaOperator(BaseOperator):
    """Execute HQL statements against Impala.

    The statements can be supplied inline via ``hql`` or loaded from a file
    via ``hql_file``.  A file may contain several named sections of the form::

        --[section_name]
        SELECT ...
        --[end]

    in which case ``hql_name`` selects the section to run; it may be omitted
    when the file contains exactly one non-empty section.

    :param hql: the HQL to execute (ignored when ``hql_file`` is given).
    :param impala_conn_id: reference to the Impala connection id (templated).
    :param db: database name; stored on the operator, applied by the hook.
    :param hql_file: path of a ``.hql``/``.sql`` file to read statements from.
    :param hql_name: name of the section inside ``hql_file`` to execute.
    :param impalaconf_jinja_translate: when True, rewrite ``${batch_date}`` /
        ``${run_date}`` (upper- or lower-case) tokens into the Jinja macro
        ``{{ ds }}`` before templating.
    :param run_as_owner: run the statements as the DAG owner.
    """

    template_fields = ('hql', 'impala_conn_id')
    template_ext = ('.hql', '.sql')
    ui_color = '#f0e4ec'

    # BUGFIX(review): the original expressions used character classes such as
    # ``[ batch_date|run_date|BATCH_DATE|RUN_DATE]`` which match *any string
    # built from those characters* (e.g. ``${date_rate}``), not the literal
    # token names.  Rewritten as a real alternation so only the intended date
    # tokens are translated.  Compiled once at class-definition time.
    _DATE_TOKEN_PATTERN = re.compile(
        r"\$\{\s*(?:batch_date|run_date|BATCH_DATE|RUN_DATE)\s*\}"
    )
    # One named section: ``--[name]`` ... ``--[end]``.  The original pattern
    # carried a trailing lazy ``.*?`` which always matched the empty string;
    # it has been dropped with no change in behavior.
    _SECTION_PATTERN = re.compile(r"--\[(.*?)\](.*?)\n--\[end\]", re.S)

    def __init__(
        self,
        *,
        hql: Optional[str] = None,  # was annotated ``str`` despite the None default
        impala_conn_id: str = 'impala_default',
        db: str = 'default',
        hql_file: Optional[str] = None,
        hql_name: Optional[str] = None,
        impalaconf_jinja_translate: bool = False,
        run_as_owner: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.hql = hql
        self.impala_conn_id = impala_conn_id
        self.db = db
        self.hql_file = hql_file
        self.hql_name = hql_name
        self.impalaconf_jinja_translate = impalaconf_jinja_translate
        self.run_as: Optional[str] = None
        if run_as_owner:
            # NOTE(review): this requires the operator to already be attached
            # to a DAG at construction time — ``self.dag`` raises otherwise.
            # Confirm callers always pass ``dag=``/use a DAG context manager.
            self.run_as = self.dag.owner
        # Assigned lazily by execute()/dry_run(); declared here so on_kill()
        # can always test it, even before execution starts.
        self.hook: Optional[ImpalaHook] = None

    def get_hook(self) -> ImpalaHook:
        """Return a fresh Impala hook."""
        return ImpalaHook()

    def _parse_sqlfile(self) -> str:
        """Read ``self.hql_file`` and return the selected HQL section.

        :raises AirflowException: if ``hql_name`` is not found in the file, or
            if ``hql_name`` is omitted while the file does not contain exactly
            one non-empty section.
        """
        with open(self.hql_file, 'r', encoding="utf8") as f:
            file_content = f.read()
        if self.impalaconf_jinja_translate:
            file_content = self._DATE_TOKEN_PATTERN.sub(r"{{ ds }}", file_content)
        # Empty-bodied sections are dropped so a stub ``--[name]--[end]`` pair
        # never shadows a real statement.
        sqls_dict = {
            name: body
            for name, body in self._SECTION_PATTERN.findall(file_content)
            if body != ""
        }
        if self.hql_name:
            try:
                return sqls_dict[self.hql_name]
            except KeyError:
                # Was a bare KeyError; surface a descriptive operator error
                # consistent with the message below.
                raise AirflowException(
                    "Section [%s] not found in %s!" % (self.hql_name, self.hql_file)
                ) from None
        if len(sqls_dict) == 1:
            return next(iter(sqls_dict.values()))
        raise AirflowException("You must specify `hql_name` when `hql_file` is defined!")

    def prepare_template(self) -> None:
        """Resolve ``hql`` from ``hql_file`` and optionally translate date tokens."""
        if self.hql_file is not None:
            # _parse_sqlfile() applies the Jinja translation itself when the
            # flag is set, so the file path needs no second pass here.
            self.hql = self._parse_sqlfile()
        elif self.impalaconf_jinja_translate:
            self.hql = self._DATE_TOKEN_PATTERN.sub(r"{{ ds }}", self.hql)

    def execute(self, context: Dict[str, Any]) -> None:
        """Run the prepared HQL through the Impala hook."""
        self.log.info('Executing: %s', self.hql)
        self.hook = self.get_hook()
        self.hook.run_hql(self.hql)
        self.log.info("Executed successfully!")

    def dry_run(self) -> None:
        """Syntax-check the HQL without executing it."""
        # Reset airflow environment variables to prevent
        # existing env vars from impacting behavior.
        self.clear_airflow_vars()
        self.hook = self.get_hook()
        self.hook.test_hql(hql=self.hql)

    def on_kill(self) -> None:
        """Ask the hook to cancel the running statement, if any."""
        if self.hook:
            self.hook.kill()

    def clear_airflow_vars(self) -> None:
        """Reset airflow environment variables to prevent existing ones from impacting behavior."""
        blank_env_vars = {
            value['env_var_format']: ''
            for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
        }
        os.environ.update(blank_env_vars)