# coding: utf8
import contextlib
import os
import re
import time
import pandas
import unicodecsv as csv
from collections import OrderedDict
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any, Dict, List, Optional, Union
from impala.dbapi import connect
from impala.error import ProgrammingError
from airflow.exceptions import AirflowException
from airflow.hooks.dbapi import DbApiHook
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook
def get_context_from_env_var() -> Dict[Any, Any]:
"""
Extract context from env variable, e.g. dag_id, task_id and execution_date,
so that they can be used inside BashOperator and PythonOperator.
:return: The context of interest.
"""
return {
format_map['default']: os.environ.get(format_map['env_var_format'], '')
for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
}
class ImpalaHook(DbApiHook):
"""
Wrapper around the impyla library
    Notes:
        * the default auth_mechanism is PLAIN; to override it you
          can specify it in the ``extra`` of your connection in the UI
        * the default for run_set_variable_statements is true; because Impala
          may not accept arbitrary ``SET`` statements, you may need to set it
          to false in the ``extra`` of your connection in the UI
    :param impala_conn_id: Reference to the Impala connection id.
    :type impala_conn_id: str
:param schema: Impala database name.
:type schema: Optional[str]
"""
conn_name_attr = 'impala_conn_id'
default_conn_name = 'impala_default'
conn_type = 'impala'
hook_name = 'Impala Thrift'
supports_autocommit = False
def get_conn(self, schema: Optional[str] = None) -> Any:
"""Returns a Impala connection object."""
username: Optional[str] = None
password: Optional[str] = None
db = self.get_connection(self.impala_conn_id)
auth_mechanism = db.extra_dejson.get('auth_mechanism', 'PLAIN')
password = db.password
return connect(
host=db.host,
port=db.port,
auth_mechanism=auth_mechanism,
user=db.login or username,
password=password,
database=db.schema or schema or 'default',
)
def _get_results(
self,
        hql: Union[str, List[str]],
schema: str = 'default',
fetch_size: Optional[int] = None,
impala_conf: Optional[Dict[Any, Any]] = None,
) -> Any:
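        """
        Execute the given statement(s) on a single cursor and yield results
        lazily: for every result-producing statement the cursor description
        (header) is yielded first, followed by its result rows.
        """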
if isinstance(hql, str):
hql = [hql_ for hql_ in hql.split(";") if hql_ != ""]
previous_description = None
with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
cur.arraysize = fetch_size or 1000
            # not all query services (e.g. Impala, see AIRFLOW-4434) support the SET command
            db = self.get_connection(self.impala_conn_id)  # type: ignore
            if db.extra_dejson.get('run_set_variable_statements', True):
                # pass the Airflow context and impala_conf to the session as SET statements
                env_context = get_context_from_env_var()
                if impala_conf:
                    env_context.update(impala_conf)
                for k, v in env_context.items():
                    cur.execute(f"SET {k}={v}")
for statement in hql:
self.log.info(statement)
cur.execute(statement)
                # we only get results from statements that return rows
lowered_statement = statement.lower().strip()
if (
lowered_statement.startswith('select')
or lowered_statement.startswith('with')
or lowered_statement.startswith('show')
):
description = cur.description
if previous_description and previous_description != description:
message = '''The statements are producing different descriptions:
Current: {}
Previous: {}'''.format(
repr(description), repr(previous_description)
)
raise ValueError(message)
elif not previous_description:
previous_description = description
yield description
try:
# DB API 2 raises when no results are returned
# we're silencing here as some statements in the list
# may be `SET` or DDL
yield from cur
except ProgrammingError:
self.log.debug("get_results returned no records")
def get_results(
self,
        hql: Union[str, List[str]],
schema: str = 'default',
fetch_size: Optional[int] = None,
impala_conf: Optional[Dict[Any, Any]] = None,
) -> Dict[str, Any]:
"""
Get results of the provided hql in target schema.
:param hql: hql to be executed.
:type hql: str or list
:param schema: target schema, default to 'default'.
:type schema: str
        :param fetch_size: number of result rows fetched per round trip (cursor ``arraysize``).
        :type fetch_size: int
        :param impala_conf: impala_conf to execute along with the hql.
:type impala_conf: dict
:return: results of hql execution, dict with data (list of results) and header
:rtype: dict
"""
results_iter = self._get_results(hql, schema, fetch_size=fetch_size, impala_conf=impala_conf)
header = next(results_iter)
results = {'data': list(results_iter), 'header': header}
return results
def to_csv(
self,
        hql: Union[str, List[str]],
csv_filepath: str,
schema: str = 'default',
delimiter: str = ',',
lineterminator: str = '\r\n',
output_header: bool = True,
fetch_size: int = 1000,
impala_conf: Optional[Dict[Any, Any]] = None,
) -> None:
"""
Execute hql in target schema and write results to a csv file.
:param hql: hql to be executed.
:type hql: str or list
:param csv_filepath: filepath of csv to write results into.
:type csv_filepath: str
:param schema: target schema, default to 'default'.
:type schema: str
:param delimiter: delimiter of the csv file, default to ','.
:type delimiter: str
:param lineterminator: lineterminator of the csv file.
:type lineterminator: str
        :param output_header: whether to write the column header to the csv file, default to True.
        :type output_header: bool
        :param fetch_size: number of result rows fetched at a time (also the progress logging interval), default to 1000.
        :type fetch_size: int
        :param impala_conf: impala_conf to execute along with the hql.
:type impala_conf: dict
"""
results_iter = self._get_results(hql, schema, fetch_size=fetch_size, impala_conf=impala_conf)
header = next(results_iter)
message = None
i = 0
with open(csv_filepath, 'wb') as file:
writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding='utf-8')
try:
if output_header:
self.log.debug('Cursor description is %s', header)
writer.writerow([c[0] for c in header])
for i, row in enumerate(results_iter, 1):
writer.writerow(row)
if i % fetch_size == 0:
self.log.info("Written %s rows so far.", i)
except ValueError as exception:
message = str(exception)
if message:
# need to clean up the file first
os.remove(csv_filepath)
raise ValueError(message)
self.log.info("Done. Loaded a total of %s rows.", i)
def get_records(
self,
        hql: Union[str, List[str]],
schema: str = 'default',
impala_conf: Optional[Dict[Any, Any]] = None
) -> Any:
"""
        Get a set of records from an Impala query.
:param hql: hql to be executed.
:type hql: str or list
:param schema: target schema, default to 'default'.
:type schema: str
        :param impala_conf: impala_conf to execute along with the hql.
:type impala_conf: dict
:return: result of impala execution
:rtype: list
>>> hh = ImpalaHook()
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
>>> len(hh.get_records(sql))
100
"""
return self.get_results(hql, schema=schema, impala_conf=impala_conf)['data']
def get_pandas_df(
self,
        hql: Union[str, List[str]],
schema: str = 'default',
impala_conf: Optional[Dict[Any, Any]] = None,
**kwargs,
) -> pandas.DataFrame:
"""
        Get a pandas dataframe from an Impala query.
:param hql: hql to be executed.
:type hql: str or list
:param schema: target schema, default to 'default'.
:type schema: str
        :param impala_conf: impala_conf to execute along with the hql.
:type impala_conf: dict
:param kwargs: (optional) passed into pandas.DataFrame constructor
:type kwargs: dict
:return: result of impala execution
:rtype: DataFrame
>>> hh = ImpalaHook()
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
>>> df = hh.get_pandas_df(sql)
>>> len(df.index)
100
"""
res = self.get_results(hql, schema=schema, impala_conf=impala_conf)
df = pandas.DataFrame(res['data'], **kwargs)
df.columns = [c[0] for c in res['header']]
return df
def run_hql(
self,
        hql: Union[str, List[str]],
schema: str = 'default',
impala_conf: Optional[Dict[Any, Any]] = None
) -> None:
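        """
        Execute the given hql statement(s) against the target schema without
        fetching any results.
        """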
if isinstance(hql, str):
hql = [hql_ for hql_ in hql.split(";") if hql_ != ""]
with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
            # not all query services (e.g. Impala, see AIRFLOW-4434) support the SET command
            db = self.get_connection(self.impala_conn_id)  # type: ignore
            if db.extra_dejson.get('run_set_variable_statements', True):
                # pass the Airflow context and impala_conf to the session as SET statements
                env_context = get_context_from_env_var()
                if impala_conf:
                    env_context.update(impala_conf)
                for k, v in env_context.items():
                    cur.execute(f"SET {k}={v}")
for statement in hql:
self.log.info("RUN HQL [%s (...)]", statement if len(statement) < 1000 else statement[:1000])
cur.execute(statement)
def test_hql(self,
                 hql: str,
schema: str = 'default',
impala_conf: Optional[Dict[Any, Any]] = None
) -> None:
"""Test an hql statement using the Impala cusor and EXPLAIN"""
create, insert, select = [], [], []
for query in hql.split(';'):
query_original = query
query = query.lower().strip()
if query.startswith('create table'):
create.append(query_original)
elif query.startswith('insert'):
insert.append(query_original)
elif query.startswith('select'):
select.append(query_original)
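        # Validate each statement by prefixing it with EXPLAIN, which asks
        # Impala to plan the query without executing it.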
for query_set in [create, insert, select]:
for query in query_set:
query = 'explain ' + query
try:
# self.get_results(hql, schema=schema, impala_conf=impala_conf)['data']
self.run_hql(query, schema=schema, impala_conf=impala_conf)
except AirflowException as e:
message = e.args[0].split('\n')[-2]
self.log.info(message)
error_loc = re.search(r'(\d+):(\d+)', message)
if error_loc and error_loc.group(1).isdigit():
lst = int(error_loc.group(1))
begin = max(lst - 2, 0)
end = min(lst + 3, len(query.split('\n')))
context = '\n'.join(query.split('\n')[begin:end])
self.log.info("Context :\n %s", context)
else:
self.log.info("SUCCESS")
def load_df(
self,
df: pandas.DataFrame,
table: str,
hdfspath: str = "/external/user/yumingmin",
field_dict: Optional[Dict[Any, Any]] = None,
delimiter: str = ',',
encoding: str = 'utf8',
pandas_kwargs: Any = None,
**kwargs: Any,
) -> None:
"""
        Load a pandas DataFrame into Impala.
        Impala column types will be inferred if not passed but column names will
        not be sanitized.
        :param df: DataFrame to load into an Impala table
        :type df: pandas.DataFrame
        :param table: target Impala table, use dot notation to target a
            specific database
        :type table: str
        :param hdfspath: HDFS directory used as the external table ``LOCATION``
            and as the upload destination
        :type hdfspath: str
        :param field_dict: mapping from column name to Impala data type.
            Note that it must be OrderedDict so as to keep columns' order.
        :type field_dict: collections.OrderedDict
:param delimiter: field delimiter in the file
:type delimiter: str
:param encoding: str encoding to use when writing DataFrame to file
:type encoding: str
:param pandas_kwargs: passed to DataFrame.to_csv
:type pandas_kwargs: dict
:param kwargs: passed to self.load_file
"""
def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]:
dtype_kind_hive_type = {
'b': 'BOOLEAN', # boolean
'i': 'BIGINT', # signed integer
'u': 'BIGINT', # unsigned integer
'f': 'DOUBLE', # floating-point
'c': 'STRING', # complex floating-point
'M': 'TIMESTAMP', # datetime
'O': 'STRING', # object
'S': 'STRING', # (byte-)string
'U': 'STRING', # Unicode
'V': 'STRING', # void
}
order_type = OrderedDict()
            for col, dtype in df.dtypes.items():
order_type[col] = dtype_kind_hive_type[dtype.kind]
return order_type
if pandas_kwargs is None:
pandas_kwargs = {}
with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir:
with NamedTemporaryFile(dir=tmp_dir, mode="w") as f:
if field_dict is None:
field_dict = _infer_field_types_from_df(df)
df.to_csv(
path_or_buf=f,
sep=delimiter,
header=False,
index=False,
encoding=encoding,
chunksize=10000,
date_format="%Y-%m-%d %H:%M:%S",
**pandas_kwargs,
)
f.flush()
return self.load_file(
filepath=f.name,
table=table,
hdfspath=hdfspath,
delimiter=delimiter,
field_dict=field_dict,
recreate=True,
tblproperties={"EXTERNAL": "TRUE"},
**kwargs
)
def load_file(
self,
filepath: str,
table: str,
hdfspath: str = "/external/user/yumingmin",
delimiter: str = ",",
field_dict: Optional[Dict[Any, Any]] = None,
create: bool = True,
overwrite: bool = True,
partition: Optional[Dict[str, Any]] = None,
recreate: bool = False,
tblproperties: Optional[Dict[str, Any]] = None,
) -> None:
"""
        Load a local file into Impala (backed by HDFS).
        Note that the table generated in Impala uses ``STORED AS textfile``
        which isn't the most efficient serialization format. If a
        large amount of data is loaded and/or if the table gets
        queried considerably, you may want to use this operator only to
        stage the data into a temporary table before loading it into its
        final destination using an ``ImpalaOperator``.
        :param filepath: local filepath of the file to load
        :type filepath: str
        :param hdfspath: HDFS directory used as the table ``LOCATION`` and as the upload destination
        :type hdfspath: str
:param table: target Impala table, use dot notation to target a
specific database
:type table: str
:param delimiter: field delimiter in the file
:type delimiter: str
:param field_dict: A dictionary of the fields name in the file
as keys and their Impala types as values.
Note that it must be OrderedDict so as to keep columns' order.
:type field_dict: collections.OrderedDict
:param create: whether to create the table if it doesn't exist
:type create: bool
:param overwrite: whether to overwrite the data in table or partition
:type overwrite: bool
:param partition: target partition as a dict of partition columns
and values
:type partition: dict
:param recreate: whether to drop and recreate the table at every
execution
:type recreate: bool
:param tblproperties: TBLPROPERTIES of the impala table being created
:type tblproperties: dict
"""
hql = ''
if recreate:
hql += f"\nDROP TABLE IF EXISTS {table};"
if create or recreate:
if field_dict is None:
raise ValueError("Must provide a field dict when creating a table")
fields = " " + ",\n ".join(f"`{k.strip('`')}` {v}" for k, v in field_dict.items())
hql += f"\nCREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
if partition:
pfields = ",\n ".join(p + " STRING" for p in partition)
hql += f"PARTITIONED BY ({pfields})\n"
hql += "ROW FORMAT DELIMITED\n"
hql += f"FIELDS TERMINATED BY '{delimiter}'\n"
hql += f"STORED AS textfile LOCATION '{hdfspath}'\n"
if tblproperties is not None:
tprops = ", ".join(f"'{k}'='{v}'" for k, v in tblproperties.items())
hql += f"TBLPROPERTIES({tprops})\n"
hql += f"; \nCOMPUTE STATS {table}"
self.log.info(hql)
self.run_hql(hql)
# Upload a file using WebHDFSHook
WebHDFSHook(proxy_user="yumingmin").load_file(source=filepath, destination=hdfspath, overwrite=overwrite)
def kill(self) -> None:
"""Kill Hive cli command"""
if hasattr(self, 'sub_process'):
if self.sub_process.poll() is None:
print("Killing the Hive job")
self.sub_process.terminate()
time.sleep(60)
self.sub_process.kill()
class ImpalaMetastoreHook(ImpalaHook):
    def check_for_partition(self, db: Optional[str] = None, table: Optional[str] = None, partition: Optional[str] = None) -> bool:
"""
Checks whether a partition exists
        :param db: Name of impala database @table belongs to
        :type db: str
        :param table: Name of impala table @partition belongs to
        :type table: str
        :param partition: Expression that matches the partitions to check for
            (eg `a = 'b' AND c = 'd'`)
        :type partition: str
:rtype: bool
>>> hh = ImpalaMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
True
"""
if '.' in table:
db, table = table.split('.')[:2]
if not self.table_exists(db=db, table=table):
raise Exception(f"{db}.{table} does not exist!")
partition_names = self.get_table_partiton_name(db=db, table=table)
        if len(partition_names) == 0:
            raise Exception(f"{db}.{table} is not a partitioned table!")
        elif partition.split("=")[0] not in partition_names:
            # fall back to the table's first partition column, quoting the value if needed
            value = partition.split("=")[1]
            partition = "%s=%s" % (partition_names[0], value if "'" in value else "'%s'" % value)
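        # Existence is checked by counting rows that match the partition
        # predicate rather than by querying the metastore directly.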
hql = f"SELECT COUNT(1) AS rows_num FROM {db}.{table} WHERE 1=1 AND {partition}"
rows_num = self.get_pandas_df(hql)["rows_num"].values[0]
        return rows_num > 0
def get_table_partiton_name(self, table: str, db: str = 'default') -> list:
if '.' in table:
db, table = table.split('.')[:2]
if not self.table_exists(db=db, table=table):
raise Exception(f"{db}.{table} does not exist!")
hql = f"SHOW CREATE TABLE {db}.{table}"
ddl = self.get_pandas_df(hql)["result"].values[0]
if ddl.lower().find("partitioned by") == -1:
self.log.info(f"{db}.{table} is not partitioned table!")
return []
else:
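            # Parse the partition column out of the SHOW CREATE TABLE DDL.
            # Note: this regex assumes a single partition column of type STRING.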
partition_name = re.findall(r'partitioned by.\(\s+(.*) string.\)', ddl.lower(), re.S)
return partition_name
def get_table(self, table: str, db: str = 'default') -> Any:
"""Get a metastore table object"""
if '.' in table:
db, table = table.split('.')[:2]
hql = f"SELECT * FROM {db}.{table} LIMIT 1"
tbl_metastore = {}
try:
df = self.get_pandas_df(hql)
tbl_metastore[table] = {}
tbl_metastore[table]["columns"] = df.columns
return tbl_metastore
except Exception as e:
self.log.error(e)
return tbl_metastore
def get_tables(self, db: str, pattern: str = '*') -> Any:
"""Get a metastore table object"""
hql = f"USE {db}; \nSHOW TABLES LIKE '{pattern}'"
try:
df = self.get_pandas_df(hql)
except Exception as e:
return {}
if len(df) == 0:
return {}
else:
tbls_metastore = {}
for tb in df.name.tolist():
try:
                    tbls_metastore[tb] = self.get_table(db=db, table=tb)[tb]
except Exception as e:
tbls_metastore[tb] = {}
return tbls_metastore
def get_databases(self, db: str, pattern: str = '*') -> Any:
"""Get a metastore databases object"""
hql = f"\nSHOW DATABASES LIKE '{pattern}'"
try:
df = self.get_pandas_df(hql)
            return df.name.tolist()
except Exception as e:
return []
def get_partitions(self,
db: str,
table: str,
partition_filter: Optional[Dict[Any, Any]] = None
) -> List[Any]:
"""
Returns a list of all partitions in a table. Works only
for tables with less than 32767 (java short max val).
For subpartitioned table, the number might easily exceed this.
>>> hh = ImpalaMetastoreHook()
>>> t = 'static_babynames_partitioned'
        >>> parts = hh.get_partitions(db='airflow', table=t)
>>> len(parts)
1
>>> parts
[{'ds': '2015-01-01'}]
"""
if '.' in table:
db, table = table.split('.')[:2]
partition = self.get_table_partiton_name(db=db, table=table)
        if not partition:
            raise Exception(f"{db}.{table} is not a partitioned table!")
        hql = f"SELECT DISTINCT {partition[0]} AS value FROM {db}.{table} WHERE 1=1"
        partition_filter = partition_filter or {}
        if "part_le" in partition_filter:
            hql += f" AND {partition[0]} <= '{partition_filter['part_le']}'"
        elif "part_ge" in partition_filter:
            hql += f" AND {partition[0]} >= '{partition_filter['part_ge']}'"
        elif "part_notin" in partition_filter:
            cond = str(partition_filter['part_notin']).strip('[').strip(']')
            hql += f" AND {partition[0]} NOT IN ({cond})"
        df_parts = self.get_pandas_df(hql)
        return [{partition[0]: df_parts['value'].tolist()}]
def max_partition(self, db: str, table: str) -> Any:
"""
        Return the maximum value of the (first) partition column of a table.
        :param db: schema (database) name.
        :type db: str
        :param table: table name.
        :type table: str
        >>> hh = ImpalaMetastoreHook()
        >>> t = 'static_babynames_partitioned'
        >>> hh.max_partition(db='airflow', table=t)
'2015-01-01'
"""
if '.' in table:
db, table = table.split('.')[:2]
partition = self.get_table_partiton_name(db=db, table=table)
if not partition:
raise Exception(f"{db}.{table} is not partition table!")
hql = f"SELECT MAX({partition[0]}) AS value FROM {db}.{table}"
max_part = self.get_pandas_df(hql)["value"].values[0]
self.log.info(f"{db}.{table} max partition is: {max_part}")
return max_part
def table_exists(self, table: str, db: str = 'default') -> bool:
"""
Check if table exists
        >>> hh = ImpalaMetastoreHook()
        >>> hh.table_exists(db='airflow', table='static_babynames')
        True
        >>> hh.table_exists(db='airflow', table='does_not_exist')
        False
"""
if '.' in table:
db, table = table.split('.')[:2]
        try:
            return bool(self.get_table(table, db))
        except Exception:
            return False
    def drop_partitions(self, table: str, part_vals: List[Any], delete_data: bool = False, db: str = 'default') -> None:
"""
Drop partitions from the given table matching the part_vals input
:param table: table name.
:type table: str
:param part_vals: list of partition specs.
:type part_vals: list
        :param delete_data: Setting to control if the underlying data has to be deleted
            in addition to dropping partitions (currently not used).
        :type delete_data: bool
        :param db: Name of impala schema (database) @table belongs to
        :type db: str
        >>> hh = ImpalaMetastoreHook()
        >>> hh.drop_partitions(db='airflow', table='static_babynames',
        ...                    part_vals=['2020-05-01'])
"""
partition = self.get_table_partiton_name(db=db, table=table)
if not partition:
raise Exception(f"{db}.{table} is not partition table!")
for pval in part_vals:
hql = f"ALTER TABLE {db}.{table} DROP IF EXISTS PARTITION ({partition[0]}='{pval}')"
self.run_hql(hql)
    def add_partitions(self, table: str, part_vals: List[Any], delete_data: bool = False, db: str = 'default') -> None:
"""
        Add partitions to the given table matching the part_vals input
:param table: table name.
:type table: str
:param part_vals: list of partition specs.
:type part_vals: list
        :param delete_data: Setting to control if the underlying data has to be deleted
            when partitions are re-created (currently not used).
        :type delete_data: bool
        :param db: Name of impala schema (database) @table belongs to
        :type db: str
        >>> hh = ImpalaMetastoreHook()
        >>> hh.add_partitions(db='airflow', table='static_babynames',
        ...                   part_vals=['2020-05-01'])
"""
if '.' in table:
db, table = table.split('.')[:2]
partition = self.get_table_partiton_name(db=db, table=table)
if not partition:
raise Exception(f"{db}.{table} is not partition table!")
self.drop_partitions(db=db, table=table, part_vals=part_vals)
for pval in part_vals:
hql = f"ALTER TABLE {db}.{table} ADD PARTITION ({partition[0]}='{pval}')"
self.run_hql(hql)
def refresh_table(self, db: str, table: str) -> bool:
if '.' in table:
db, table = table.split('.')[:2]
        if not self.table_exists(db=db, table=table):
            raise AirflowException(f"{db}.{table} does not exist!")
        hql = f"INVALIDATE METADATA {db}.{table}"
        self.run_hql(hql)
        return True
# TODO:
#   - Improve the load_df functionality