简单实现表血缘图

Database and Ruby, Python, History


在 ETL 过程中,我们时常需要知道这些数据是从哪些表的哪个字段来的。我们可以写一个文档,但文档需要时刻维护。如果是用一些 ETL 工具,可以直接画出这些关系,即数据血缘。在我的项目中,由于是自己造轮子,所以需要自己画数据血缘,如下图所示。

自己写了一段代码,主要是通过解析 sql 语句,然后再用 Python 的 pydot 库生成 DOT language 描述的图来实现的。

实现代码


    import re
    import os
    from collections import defaultdict
    import json
    import pydot


    def parse_relationships_from_sql_files(sql_folder_path):
        """Collect table relationships from every ``.sql`` file under a folder.

        The folder is walked recursively with ``os.walk``. The file list of each
        visited directory and the name of each parsed ``.sql`` file are printed
        as progress output.

        Args:
            sql_folder_path: Root directory to search for ``.sql`` files.

        Returns:
            A flat list with the relationships of all files, in walk order.
        """
        collected = []
        for root, _, files in os.walk(sql_folder_path):
            print(files)
            for file_name in files:
                # Anything that is not a SQL script is ignored.
                if not file_name.endswith('.sql'):
                    continue
                print(file_name)
                collected.extend(
                    parse_relationships_from_sql_file(sql_file_path=f"{root}/{file_name}")
                )
        return collected


    def parse_relationships_from_sql_file(sql_file_path):
        """Parse one SQL file into a list of raw table relationships.

        The file content is split into statements by ``parse_sql_statements``;
        each statement is then offered to the INSERT parser and to the UPDATE
        parser, and every non-``None`` result is kept.

        Args:
            sql_file_path: Path of the SQL file to read.

        Returns:
            A list of single-target relationship dicts, for example::

                [
                  {'staging': {'dim_customers': [('app', 'account')]}},
                  {'data': {'dim_customers': [('staging', 'dim_customers'),
                      ('data', 'dim_customers')]}},
                  {'data': {'dim_customers': [('staging', 'dim_customers')]}}
                ]
        """
        with open(sql_file_path) as file:
            sql = file.read()

        sql_raw_relationships = []
        for statement in parse_sql_statements(sql):
            # INSERT is tried before UPDATE so the result order stays stable.
            for parser in (parse_insert_statement, parse_update_statement):
                relationship = parser(statement)
                if relationship is not None:
                    sql_raw_relationships.append(relationship)

        return sql_raw_relationships


    def parse_sql_statements(sql):
        """Split the body of a ``CREATE ... FUNCTION`` script into statements.

        The text between ``BEGIN`` and ``END`` (matched case-sensitively and
        greedily) is split on ``;`` and each piece is stripped of surrounding
        whitespace. Empty pieces are kept.

        Args:
            sql: Full text of a SQL script.

        Returns:
            The list of stripped statement strings, or an empty list when the
            script does not match the expected function layout.
        """
        body_match = re.search(r"CREATE.+FUNCTION[\s\S]*(?<=BEGIN)([\s\S]*)(?=END)", sql)
        if not body_match:
            return []

        return [piece.strip() for piece in body_match.group(1).split(';')]


    def parse_insert_statement(sql):
        """Extract the target table and source tables of an INSERT statement.

        The target is the ``schema.table`` after ``INSERT INTO`` (a column list
        in parentheses is required). Sources are every ``schema.table`` that
        follows a ``FROM`` or ``JOIN`` keyword, matched case-insensitively. A
        source is accepted when it is followed by another token (an alias or a
        keyword) or when it ends the statement, so un-aliased tables at the end
        of a statement are picked up too.

        Args:
            sql: A single SQL statement.

        Returns:
            ``None`` when the statement is not a matching INSERT, otherwise a
            dict such as::

                {'data': {'dim_customers': [('staging', 'dim_customers')]}}

            FROM sources are listed before JOIN sources.
        """
        target_match = re.search(r"INSERT INTO\s*(\w+\.\w+)\s*\(([^)]*)\)", sql, re.IGNORECASE)
        if target_match is None:
            return None

        source_tables = []
        for keyword in ("FROM", "JOIN"):
            # The lookahead keeps matches such as "EXTRACT(x FROM t.col)" out
            # while still allowing a table name that ends the statement.
            source_tables += re.findall(
                keyword + r"\s+(\w+\.\w+)(?=\s+\w|\s*;|\s*$)", sql, re.IGNORECASE
            )

        schema, table = target_match.group(1).split('.')

        return {
            schema: {
                table: [tuple(name.split('.')) for name in source_tables]
            }
        }


    def parse_update_statement(sql):
        """Extract the target table and source tables of an UPDATE statement.

        The target is the ``schema.table`` after ``UPDATE``. Sources are every
        ``schema.table`` that follows a ``FROM`` or ``JOIN`` keyword, matched
        case-insensitively. A source is accepted when it is followed by another
        token (an alias or a keyword) or when it ends the statement.

        Args:
            sql: A single SQL statement.

        Returns:
            ``None`` when the statement contains no ``UPDATE schema.table``,
            otherwise a dict such as::

                {'data': {'dim_customers': [('staging', 'dim_customers')]}}

            FROM sources are listed before JOIN sources.
        """
        target_match = re.search(r"UPDATE\s+(\w+\.\w+)", sql, re.IGNORECASE)
        if target_match is None:
            return None

        source_tables = []
        for keyword in ("FROM", "JOIN"):
            # The lookahead keeps matches such as "EXTRACT(x FROM t.col)" out
            # while still allowing a table name that ends the statement.
            source_tables += re.findall(
                keyword + r"\s+(\w+\.\w+)(?=\s+\w|\s*;|\s*$)", sql, re.IGNORECASE
            )

        schema, table = target_match.group(1).split('.')

        return {
            schema: {
                table: [tuple(name.split('.')) for name in source_tables]
            }
        }


    def merge_relationships(relationships):
        """Merge per-statement relationships into one schema -> table -> sources map.

        Self references (a table listed as its own source) are skipped, and
        duplicate sources are recorded once. Every source table also gets its
        own entry, with an empty list if nothing feeds it, so it can later be
        drawn as a node.

        Args:
            relationships: Iterable of ``{schema: {table: [(src_schema, src_table), ...]}}``.

        Returns:
            A nested ``defaultdict``, for example::

                {
                  'app': {'account': []},
                  'staging': {'dim_customers': [('app', 'account')]},
                  'data': {'dim_customers': [('staging', 'dim_customers')]}
                }
        """
        merged = defaultdict(lambda: defaultdict(list))
        for relationship in relationships:
            for schema, tables in relationship.items():
                for table, sources in tables.items():
                    target = (schema, table)
                    for source in sources:
                        src_schema, src_table = source
                        # A table feeding itself is not drawn as an edge.
                        if (src_schema, src_table) == target:
                            continue
                        # Reading the defaultdict entry is what registers the
                        # source table as a node, even when it has no sources.
                        _ = merged[src_schema][src_table]
                        known_sources = merged[schema][table]
                        if source not in known_sources:
                            known_sources.append(source)

        return merged


    def generate_dot_for_relationships(relationships):
        """Render merged relationships as a Graphviz DOT digraph string.

        Each schema becomes a ``cluster_<schema>`` subgraph labelled with the
        schema name, each table becomes a box node named ``<schema>_<table>``,
        and each source -> target pair becomes an edge. Cluster and edge colours
        cycle through the eight entries of the ``dark28`` colour scheme.

        Args:
            relationships: Mapping ``{schema: {table: [(src_schema, src_table), ...]}}``.

        Returns:
            The DOT source text of the graph.
        """
        parts = [
            "digraph dw { rankdir=LR;\n graph [pad=2.0, nodesep=0.5, ranksep=4];\n splines=true;\n  colorscheme=dark28;\n  "
        ]
        edges = []

        # Draw the nodes of each schema inside their own cluster subgraph.
        for index, (schema, tables) in enumerate(relationships.items()):
            # dark28 only defines colours 1..8, so wrap around for more schemas.
            cluster_color = (index % 8) + 1
            parts.append(
                f'subgraph cluster_{schema} {{label="{schema}"; shape=box; '
                f'color="/dark28/{cluster_color}"; '
            )
            for table, sources in tables.items():
                parts.append(f' {schema}_{table}[shape=box,label="{schema}.{table}"]; ')
                edges.extend(
                    (src_schema, src_table, schema, table)
                    for src_schema, src_table in sources
                )
            parts.append("}\n")

        # Draw the arrows between nodes after all clusters are declared.
        for idx, (src_schema, src_table, schema, table) in enumerate(edges):
            edge_color = (idx % 8) + 1
            parts.append(
                f'\n{src_schema}_{src_table} -> {schema}_{table}[color="/dark28/{edge_color}"];'
            )
        parts.append("}")

        return "".join(parts)


    def export_dot(dot_content, output_file):
        """Render DOT source to a PNG file with pydot.

        Args:
            dot_content: DOT source text of the graph.
            output_file: Path of the PNG file to write.
        """
        # graph_from_dot_data returns a sequence of graphs; only the first is used.
        parsed_graphs = pydot.graph_from_dot_data(dot_content)
        parsed_graphs[0].write_png(output_file)


    # Script entry: parse every SQL file under the folder, merge the results,
    # print them as JSON, then render the lineage graph to a PNG.
    folder_path = 'sqls/'
    raw_relationships = parse_relationships_from_sql_files(folder_path)
    relationships = merge_relationships(raw_relationships)
    print(json.dumps(relationships, indent=4))
    dot_content = generate_dot_for_relationships(relationships)
    export_dot(dot_content, 'assets/images/dw_data_flow.png')