"""Gradio app that profiles a MotherDuck table and validates it with LLM-generated Pandera rules.

Flow: pick a schema/table in the UI -> copy up to SAMPLE_ROWS rows into a local
DuckDB file through a dlt pipeline -> profile the sample (ydata-profiling and
pandas ``describe``) -> ask a Hugging Face chat model for Pandera rules ->
``eval`` each rule and run it against the sample.
"""
import json
import os
import re
import warnings

import dlt
import duckdb
import gradio as gr
import pandas as pd
import pandera as pa  # referenced by the LLM-generated rules that validate_pandera eval()s
import ydata_profiling as pp
from langchain import hub
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langsmith import traceable
from pandera import Column  # referenced by the LLM-generated rules that validate_pandera eval()s

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Height of the Tabs Text Area
TAB_LINES = 8

# Maximum number of rows copied from the source table and analysed.
SAMPLE_ROWS = 1000

#----------CONNECT TO DATABASE----------
md_token = os.getenv('MD_TOKEN')
# Only embed the token when it is set: formatting None into the URL would send the
# literal string "None" as the token. NOTE(review): without it, MotherDuck presumably
# falls back to its own credential lookup — confirm for your deployment.
_md_dsn = f"md:my_db?motherduck_token={md_token}" if md_token else "md:my_db"
conn = duckdb.connect(_md_dsn, read_only=True)
#---------------------------------------

#-------LOAD HUGGINGFACE-------
models = ["Qwen/Qwen2.5-72B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct"]
endpoint = None
for model in models:
    try:
        candidate = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
        # Probe the endpoint so an unavailable model is skipped now instead of failing on first use.
        candidate.client.get_endpoint_info()
    except Exception as e:  # best-effort fallback: try the next model in the list
        print(f"Error for model {model}: {e}")
        continue
    endpoint = candidate
    break
model_loaded = endpoint is not None
if not model_loaded:
    # Previously the code went on with a failed (or undefined) endpoint and broke later.
    raise RuntimeError(f"Unable to load any Hugging Face model from: {', '.join(models)}")

llm = ChatHuggingFace(llm=endpoint).bind(max_tokens=8192)
#---------------------------------------

#-----LOAD PROMPT FROM LANCHAIN HUB-----
prompt_autogenerate = hub.pull("autogenerate-rules-testworkflow")
prompt_user_input = hub.pull("usergenerate-rules-testworkflow")

#--------------ALL UTILS----------------


def get_schemas():
    """Return the names of all schemas except the system ``information_schema``/``pg_catalog``."""
    schemas = conn.execute("""
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """).fetchall()
    return [item[0] for item in schemas]


def get_tables_names(schema_name):
    """Return the names of the tables in ``schema_name``."""
    # Parameterised instead of an f-string so the value is never spliced into the SQL text.
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [table[0] for table in tables]


def update_table_names(schema_name):
    """Refresh the table dropdown for ``schema_name`` and clear the previous selection."""
    # value=None so a table from the previously selected schema is not left selected.
    return gr.update(choices=get_tables_names(schema_name), value=None)


@dlt.resource
def fetch_data(schema):
    """Yield up to SAMPLE_ROWS rows of ``schema`` (``db.schema.table``) as DataFrame chunks for dlt."""
    result = conn.sql(f"SELECT * FROM {schema} LIMIT {SAMPLE_ROWS}")
    while True:
        chunk_df = result.fetch_df_chunk(2)  # argument is DuckDB's vectors-per-chunk
        if chunk_df is None or len(chunk_df) == 0:
            break
        yield chunk_df


def create_pipeline(schema):
    """Copy the sample of ``schema`` (``db.schema.table``) into the local dlt DuckDB destination.

    Returns:
        ``"<dataset>.<table>"`` — the location of the copy inside the dlt DuckDB file.
    """
    parts = schema.split('.')
    dataset_name = parts[1]
    print("Dataset Name: ", dataset_name)
    table_name = parts[2]
    print("Table Name: ", table_name)

    pipeline = dlt.pipeline(
        pipeline_name='duckdb_pipeline',
        destination='duckdb',
        dataset_name=dataset_name,
    )
    load_info = pipeline.run(fetch_data(schema), table_name=table_name, write_disposition="replace")
    print(load_info)
    return dataset_name + "." + table_name


def load_pipeline(table_name):
    """Open the local dlt DuckDB file and read up to SAMPLE_ROWS rows of ``table_name``.

    Returns:
        ``(connection, DataFrame)``; the caller owns the connection and must close it.
    """
    _conn = duckdb.connect("duckdb_pipeline.duckdb")
    try:
        df = _conn.sql(f"SELECT * FROM {table_name} LIMIT {SAMPLE_ROWS}").df()
    except Exception:
        _conn.close()  # don't leak the connection when the query fails
        raise
    return _conn, df


def df_summary(df):
    """Build a per-column summary (max/min/count/nunique/dtype/top) used in the LLM prompt.

    Numeric columns get ``max``/``min``; object and categorical columns get their most
    frequent value as ``top``. Columns matching neither branch are omitted.
    """
    summary = []
    for column in df.columns:
        series = df[column]
        if pd.api.types.is_numeric_dtype(series):
            summary.append({
                "column": column,
                "max": series.max(),
                "min": series.min(),
                "count": series.count(),
                "nunique": series.nunique(),
                "dtype": str(series.dtype),
                "top": None,
            })
        # isinstance(..., CategoricalDtype) replaces the deprecated is_categorical_dtype.
        elif isinstance(series.dtype, pd.CategoricalDtype) or pd.api.types.is_object_dtype(series):
            modes = series.mode()  # computed once instead of twice
            summary.append({
                "column": column,
                "max": None,
                "min": None,
                "count": series.count(),
                "nunique": series.nunique(),
                "dtype": str(series.dtype),
                "top": None if modes.empty else modes.iloc[0],
            })
    return pd.DataFrame(summary).reset_index(drop=True)


def format_prompt(df):
    """Fill the auto-generate prompt with the first rows of ``df`` and its column summary."""
    summary = df_summary(df)
    return prompt_autogenerate.format_prompt(
        data=df.head().to_json(orient='records'),
        summary=summary.to_json(orient='records'),
    )


def format_user_prompt(df, user_description=None):
    """Fill the user prompt with the first rows of ``df`` and, if given, the user's request.

    ``user_description`` is optional so existing ``format_user_prompt(df)`` callers keep working.
    It is forwarded only when the hub prompt declares a ``user_description`` input variable.
    NOTE(review): that variable name is taken from the original call site — confirm
    it against the hub prompt "usergenerate-rules-testworkflow".
    """
    variables = {"data": df.head().to_json(orient='records')}
    if user_description is not None:
        if "user_description" in (getattr(prompt_user_input, "input_variables", None) or []):
            variables["user_description"] = user_description
        else:
            print("Warning: user prompt has no 'user_description' variable; query not forwarded")
    return prompt_user_input.format_prompt(**variables)


def process_inputs(inputs):
    """Reduce the traced inputs of ``run_llm`` to the second prompt message for LangSmith."""
    return {'input_query': inputs['messages'].to_messages()[1]}


# First ``` fenced block (optionally tagged json) in an LLM reply.
_CODE_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)


def _strip_code_fences(text):
    """Return the body of the first ``` fenced block in ``text``, or ``text`` stripped if none."""
    match = _CODE_FENCE_RE.search(text)
    return match.group(1) if match else text.strip()


@traceable(process_inputs=process_inputs)
def run_llm(messages):
    """Send ``messages`` to the LLM and parse its reply as JSON.

    Returns:
        The parsed JSON (expected: a list of test dicts), or the Exception raised
        while invoking the model or parsing its reply. Callers check with isinstance.
    """
    try:
        response = llm.invoke(messages)
        print(response.content)
        # Only strip the markdown fence: the old blanket .replace("json", "") also
        # deleted "json" from column names and rules inside the payload.
        tests = json.loads(_strip_code_fences(response.content))
    except Exception as e:  # surfaced to the UI by the callers
        return e
    return tests


def get_table_schema(table):
    """Return the fully-qualified ``database.schema.table`` path for ``table``.

    Raises:
        gr.Error: if no table is selected or no table with that name exists.
    """
    if not table:
        raise gr.Error("Please select a table first.")
    row = conn.execute(
        "SELECT database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchone()
    if row is None:
        raise gr.Error(f"Table '{table}' was not found.")
    # NOTE(review): only the table name is matched, so if several schemas/databases
    # contain a table with this name, the first row DuckDB returns wins.
    parent_database, schema_name = row
    return f"{parent_database}.{schema_name}.{table}"


def describe(df):
    """Return ``(numerical_info, categorical_info)`` from pandas ``describe``, one row per column.

    Either frame is empty when ``df`` has no columns of that kind.
    """
    numerical_info = pd.DataFrame()
    categorical_info = pd.DataFrame()

    numeric = df.select_dtypes(include=['number'])
    if len(numeric.columns) >= 1:
        numerical_info = numeric.describe().T.reset_index().rename(columns={'index': 'column'})

    objects = df.select_dtypes(include=['object'])
    if len(objects.columns) >= 1:
        categorical_info = objects.describe().T.reset_index().rename(columns={'index': 'column'})

    return numerical_info, categorical_info


def validate_pandera(tests, df):
    """Evaluate each LLM-generated Pandera rule and run it against its column of ``df``.

    Each test is expected to be a dict with ``column_name`` and ``pandera_rule`` keys.
    Malformed tests are reported as failures instead of aborting the whole run.

    Returns:
        DataFrame with ``Columns`` and ``Result`` (pass / fail with the error message).
    """
    validation_results = []
    for test in tests:
        column_name = test.get('column_name') if isinstance(test, dict) else None
        try:
            # SECURITY: eval() runs code produced by the LLM with this module's globals.
            # Treat it as untrusted; sandbox or parse the rule before exposing this app.
            rule = eval(test['pandera_rule'])
            rule(df[[column_name]])
            validation_results.append({"Columns": column_name, "Result": "✅ Pass"})
        except Exception as e:
            validation_results.append({"Columns": column_name, "Result": f"❌ Fail - {str(e)}"})
    return pd.DataFrame(validation_results)


def statistics(df):
    """Profile ``df`` with ydata-profiling.

    Returns:
        ``(df_statistics, df_alerts)``: table-level statistics with readable labels,
        and the profiler's alerts with their category.
    """
    profile = pp.ProfileReport(df)
    report_dict = profile.get_description()
    description, alerts = report_dict.table, report_dict.alerts

    # Statistics
    mapping = {
        'n': 'Number of observations',
        'n_var': 'Number of variables',
        'n_cells_missing': 'Number of cells missing',
        'n_vars_with_missing': 'Number of columns with missing data',
        'n_vars_all_missing': 'Columns with all missing data',
        'p_cells_missing': 'Missing cells (%)',
        'n_duplicates': 'Duplicated rows',
        'p_duplicates': 'Duplicated rows (%)',
    }
    rows = []
    for key, value in description.items():
        if key == 'types':
            continue
        if key.startswith('p_'):
            # p_* values are fractions; the old int cast truncated them to 0, so
            # convert them to the percentage their "(%)" label promises.
            rows.append((mapping.get(key, key), round(float(value) * 100, 2)))
        else:
            rows.append((mapping.get(key, key), int(value)))

    # Add flattened types information
    types = description.get('types', {})
    for type_key, label in (
        ('Text', 'Number of text columns'),
        ('Categorical', 'Number of categorical columns'),
        ('Numeric', 'Number of numeric columns'),
        ('DateTime', 'Number of datetime columns'),
    ):
        if type_key in types:
            rows.append((label, int(types[type_key])))

    # dtype=object keeps integer counts from being upcast to floats next to the percentages.
    df_statistics = pd.DataFrame(rows, columns=['Statistic Description', 'Value'], dtype=object)

    # Alerts
    alerts_list = [(str(alert).replace('[', '').replace(']', ''), alert.alert_type_name) for alert in alerts]
    df_alerts = pd.DataFrame(alerts_list, columns=['Data Quality Issue', 'Category'])

    return df_statistics, df_alerts


def _load_table(table):
    """Copy a sample of ``table`` into the local dlt DuckDB file; return ``(connection, DataFrame)``."""
    schema = get_table_schema(table)
    table_name = create_pipeline(schema)
    return load_pipeline(table_name)


def _as_test_list(tests):
    """Wrap a single test dict returned by the LLM in a list; pass anything else through."""
    return [tests] if isinstance(tests, dict) else tests


def _tests_to_frame(tests):
    """Build the rules table, renaming the first three columns to Column / Rule Name / Rules."""
    tests_df = pd.DataFrame(tests)
    # zip() tolerates replies with fewer than three columns instead of raising IndexError.
    return tests_df.rename(columns=dict(zip(tests_df.columns, ('Column', 'Rule Name', 'Rules'))))


def _error_frame(error):
    """One-row DataFrame describing why test generation failed."""
    return pd.DataFrame([{"error": f"❌ Unable to generate tests. {error}"}])

#---------------------------------------


def main(table):
    """Handle "Validate Data": profile ``table``, generate rules with the LLM and validate them.

    Returns the seven values shown by the UI: data preview, statistics, alerts,
    categorical info, numerical info, generated rules and validation results.
    """
    connection, df = _load_table(table)
    try:
        df_statistics, df_alerts = statistics(df)
        describe_num, describe_cat = describe(df)

        tests = run_llm(format_prompt(df=df))
        if isinstance(tests, Exception):
            return (df.head(10), df_statistics, df_alerts, describe_cat, describe_num,
                    _error_frame(tests), pd.DataFrame([]))

        tests = _as_test_list(tests)
        tests_df = _tests_to_frame(tests)
        pandera_results = validate_pandera(tests, df)
    finally:
        # Closed on every path; previously the error path leaked the connection.
        connection.close()

    return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results


def user_results(table, text_query):
    """Handle "Text to Validation": generate rules from ``text_query`` and validate ``table``.

    Returns ``(rules DataFrame, validation results DataFrame)``.
    """
    connection, df = _load_table(table)
    try:
        # Previously this passed an unsupported keyword and never called the LLM,
        # so the handler always raised.
        messages = format_user_prompt(df=df, user_description=text_query)
        tests = run_llm(messages)
        print(f'Generated Tests from user input: {tests}')
        if isinstance(tests, Exception):
            return _error_frame(tests), pd.DataFrame([])

        tests = _as_test_list(tests)
        tests_df = _tests_to_frame(tests)
        pandera_results = validate_pandera(tests, df)
        print('Validated Tests with Pandera')
    finally:
        connection.close()

    return tests_df, pandera_results


# Custom CSS styling
custom_css = """
.gradio-container {
    background-color: #f0f4f8;
}
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}
.gr-button {
    background-color: #4a90e2 !important;
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}
"""

with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)

    gr.Markdown("""
Dataset Test Workflow
Implement and Automate Data Validation Processes.
""")

    with gr.Row():
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)

            with gr.Row():
                generate_result = gr.Button("Validate Data", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column():
                            data_description = gr.DataFrame(label="Data Description", value=[], interactive=False)
                    with gr.Row():
                        with gr.Column():
                            describe_cat = gr.DataFrame(label="Categorical Information", value=[], interactive=False)
                        with gr.Column():
                            describe_num = gr.DataFrame(label="Numerical Information", value=[], interactive=False)
                with gr.Tab("Alerts"):
                    data_alerts = gr.DataFrame(label="Alerts", value=[], interactive=False)
                with gr.Tab("Rules & Validations"):
                    tests_output = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                    test_result_output = gr.DataFrame(label="Validation Result", value=[], interactive=False)
                with gr.Tab("Data"):
                    result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)
                with gr.Tab('Text to Validation'):
                    with gr.Row():
                        query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.")
                    with gr.Row():
                        with gr.Column():
                            pass
                        with gr.Column(scale=1, min_width=50):
                            user_generate_result = gr.Button("Validate Data", variant="primary")
                    with gr.Row():
                        with gr.Column():
                            query_tests = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                        with gr.Column():
                            query_result = gr.DataFrame(label="Validation Result", value=[], interactive=False)

    schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_result.click(main, inputs=[tables_dropdown], outputs=[result_output, data_description, data_alerts, describe_cat, describe_num, tests_output, test_result_output])
    user_generate_result.click(user_results, inputs=[tables_dropdown, query_input], outputs=[query_tests, query_result])

if __name__ == "__main__":
    demo.launch(debug=True)