Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import duckdb | |
| import gradio as gr | |
| import pandas as pd | |
| import pandera as pa | |
| from pandera import Column | |
| import random | |
| from dataprep.eda import compute | |
| from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace | |
| from .utils import ( | |
| format_num_stats, format_cat_stats, | |
| format_ov_stats, format_insights | |
| ) | |
| from langsmith import traceable | |
| from langchain import hub | |
| import warnings | |
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Height of the Tabs Text Area
TAB_LINES = 8
#----------CONNECT TO DATABASE----------
md_token = os.getenv('MD_TOKEN')
# Only embed the token in the DSN when it is actually set: interpolating None
# would send the literal string "None" as the MotherDuck token. Without the
# query parameter, MotherDuck falls back to its own credential lookup
# (e.g. a `motherduck_token` environment variable).
_md_dsn = f"md:my_db?motherduck_token={md_token}" if md_token else "md:my_db"
conn = duckdb.connect(_md_dsn, read_only=True)
#---------------------------------------
#-------LOAD HUGGINGFACE-------
# Candidate models, tried in order; the first one whose endpoint answers
# `get_endpoint_info()` is used.
models = ["Qwen/Qwen2.5-72B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct",
          "meta-llama/Llama-3.1-70B-Instruct"]
model_loaded = False
endpoint = None
for model in models:
    try:
        candidate = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
    except Exception as e:
        print(f"Error for model {model}: {e}")
        continue
    # Remember the most recently constructed endpoint so that, if none can be
    # verified, we still fall back to it (LLM calls then fail per request and
    # run_llm reports the error) instead of crashing at import time.
    endpoint = candidate
    try:
        candidate.client.get_endpoint_info()
    except Exception as e:
        print(f"Error for model {model}: {e}")
        continue
    model_loaded = True
    break

if endpoint is None:
    # Previously this surfaced as an opaque NameError on `endpoint`.
    raise RuntimeError(f"Could not construct a Hugging Face endpoint for any of: {models}")
if not model_loaded:
    print("Warning: no candidate model endpoint could be verified; "
          "using the last endpoint that was constructed.")
llm = ChatHuggingFace(llm=endpoint).bind(max_tokens=4096)
#---------------------------------------
#-----LOAD PROMPTS FROM LANGCHAIN HUB-----
# Both templates are pulled once, at import time, in this order:
#   prompt_autogenerate <- "autogenerate-rules-testworkflow" (used by format_prompt)
#   prompt_user_input   <- "usergenerate-rules-testworkflow" (used by format_user_prompt)
_HUB_PROMPT_IDS = ("autogenerate-rules-testworkflow", "usergenerate-rules-testworkflow")
prompt_autogenerate, prompt_user_input = [hub.pull(prompt_id) for prompt_id in _HUB_PROMPT_IDS]
| #--------------ALL UTILS---------------- | |
# Get Schemas
def get_schemas():
    """Return the names of all non-system schemas visible on the connection.

    The ``information_schema`` and ``pg_catalog`` schemas are excluded. The
    result is sorted so the schema dropdown shows a stable order.

    Returns:
        list[str]: Distinct schema names.
    """
    rows = conn.execute("""
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
        ORDER BY schema_name
    """).fetchall()
    return [row[0] for row in rows]
# Get Tables
def get_tables_names(schema_name):
    """Return the names of the tables that live in ``schema_name``.

    ``schema_name`` comes from the UI dropdown (and can be sent directly to
    the Gradio API), so it is passed as a bound query parameter rather than
    being spliced into the SQL text.

    Args:
        schema_name: Schema to list tables for.

    Returns:
        list[str]: Table names (empty if the schema has none or is None).
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_table_names(schema_name):
    """Refresh the table dropdown after a schema is picked.

    Args:
        schema_name: The schema selected in the schema dropdown.

    Returns:
        A ``gr.update`` payload that replaces the dropdown's ``choices`` with
        the tables in that schema.
    """
    return gr.update(choices=get_tables_names(schema_name))
# Resolve a table name to its fully-qualified path
def get_table_schema(table):
    """Return the fully-qualified ``database.schema.table`` path for ``table``.

    The lookup uses ``duckdb_tables()`` and matches on the table name only. If
    several schemas contain a table with the same name, the first row DuckDB
    returns wins.

    Args:
        table: Table name selected in the UI.

    Returns:
        str: ``"<database>.<schema>.<table>"``.

    Raises:
        ValueError: If no table is selected or no table with that name exists.
    """
    if not table:
        raise ValueError("No table selected.")
    # Bind the UI-supplied name as a parameter instead of interpolating it.
    rows = conn.execute(
        "SELECT database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchall()
    if not rows:
        raise ValueError(f"Table {table!r} not found.")
    parent_database, schema_name = rows[0]
    return f"{parent_database}.{schema_name}.{table}"
def get_data_df(schema):
    """Load up to 1000 rows of a table into a pandas DataFrame.

    There is no ORDER BY, so which rows are returned is up to DuckDB.

    Args:
        schema: Fully-qualified table path (as produced by ``get_table_schema``);
            it is interpolated directly into the query text.

    Returns:
        pandas.DataFrame: The sampled rows.
    """
    print('Getting Dataframe from the Database')
    query = f"SELECT * FROM {schema} LIMIT 1000"
    relation = conn.sql(query)
    return relation.df()
def calcualte_stats(df):
    """Compute dataset-level and per-column statistics with dataprep.

    Args:
        df: DataFrame to profile. It is copied, so the caller's frame is untouched.

    Returns:
        dict with keys:
            "overall_stats": one-column DataFrame of dataset statistics.
            "insights": overview insights as formatted by ``format_insights``.
            "stats_1", "stats_2": one-column DataFrames holding the 'Overview'
                stats of up to two randomly chosen numeric/object columns. An
                empty DataFrame fills a slot when fewer than two such columns
                exist.
    """
    _df = df.copy()
    num_cols = _df.select_dtypes(include=['number'], exclude=['datetime']).columns
    cat_cols = _df.select_dtypes(include=['object'], exclude=['datetime']).columns

    _all_stats = compute(_df)
    all_stats = format_ov_stats(_all_stats['stats'])
    insights = format_insights(_all_stats['overview_insights'])

    candidates = num_cols.tolist() + cat_cols.tolist()
    indev_stats = []
    # random.sample raises ValueError if asked for more items than exist, so
    # cap the sample size for tables with fewer than two profiled columns.
    for col in random.sample(candidates, min(2, len(candidates))):
        _indv_data = compute(_df, col)
        if col in cat_cols:
            formatted = format_cat_stats(_indv_data["data"])
        else:
            try:
                formatted = format_num_stats(_indv_data["data"])
            except Exception:
                # Fall back to the categorical formatter when the numeric one
                # cannot handle the payload dataprep produced for this column.
                formatted = format_cat_stats(_indv_data["data"])
        indev_stats.append(pd.DataFrame([formatted['Overview']], index=[f'{col}_stats']).T)

    # Always provide two entries so the two UI slots can be filled.
    while len(indev_stats) < 2:
        indev_stats.append(pd.DataFrame())

    return {
        "overall_stats": pd.DataFrame(all_stats[0], index=['Dataset Statistics']).T,
        "insights": insights,
        "stats_1": indev_stats[0],
        "stats_2": indev_stats[1],
    }
def df_summary(df):
    """Build a per-column summary table used when prompting the LLM.

    Numeric columns (per ``is_numeric_dtype``, which also accepts booleans) get
    ``max``/``min``. Categorical, object and string columns get their most
    frequent value as ``top``. Columns of any other dtype (e.g. datetimes) are
    left out.

    Args:
        df: DataFrame to summarise.

    Returns:
        pandas.DataFrame with columns ``column, max, min, count, nunique, dtype,
        top`` (one row per summarised column; an empty frame if none qualify).
    """
    summary = []
    for column in df.columns:
        series = df[column]
        if pd.api.types.is_numeric_dtype(series):
            summary.append({
                "column": column, "max": series.max(), "min": series.min(),
                "count": series.count(), "nunique": series.nunique(),
                "dtype": str(series.dtype), "top": None,
            })
        # `is_categorical_dtype` is deprecated (removed in newer pandas); the
        # isinstance check is the supported replacement. `is_string_dtype`
        # also covers pandas' dedicated string dtype.
        elif (isinstance(series.dtype, pd.CategoricalDtype)
              or pd.api.types.is_object_dtype(series)
              or pd.api.types.is_string_dtype(series)):
            modes = series.mode()  # computed once instead of twice
            top_value = None if modes.empty else modes.iloc[0]
            summary.append({
                "column": column, "max": None, "min": None, "count": series.count(),
                "nunique": series.nunique(), "dtype": str(series.dtype), "top": top_value,
            })
    return pd.DataFrame(summary).reset_index(drop=True)
def format_prompt(df):
    """Render the auto-generation prompt for ``df``.

    The template receives the first five rows (``df.head()``) as ``data`` and
    the ``df_summary`` table as ``summary``, both serialised as JSON records.

    Args:
        df: Data to generate validation rules for.

    Returns:
        The prompt value produced by ``prompt_autogenerate.format_prompt``.
    """
    sample_json = df.head().to_json(orient='records')
    summary_json = df_summary(df).to_json(orient='records')
    return prompt_autogenerate.format_prompt(data=sample_json, summary=summary_json)
def format_user_prompt(df, user_description=None):
    """Render the user-driven rule-generation prompt for ``df``.

    Args:
        df: Data whose first five rows are passed to the template as ``data``
            (JSON records).
        user_description: Optional free-text validation request typed by the
            user. ``user_results`` calls this function with this keyword, so it
            must be accepted; it is forwarded to the template when given.

    Returns:
        The prompt value produced by ``prompt_user_input.format_prompt``.
    """
    prompt_vars = {"data": df.head().to_json(orient='records')}
    if user_description is not None:
        # NOTE(review): assumes the hub prompt "usergenerate-rules-testworkflow"
        # declares a `user_description` input variable — confirm on the hub.
        prompt_vars["user_description"] = user_description
    return prompt_user_input.format_prompt(**prompt_vars)
def process_inputs(inputs):
    """Extract the second message of a prompt value.

    Args:
        inputs: Mapping whose ``'messages'`` entry exposes ``to_messages()``.

    Returns:
        dict: ``{'input_query': <message at index 1>}``.
    """
    message_list = inputs['messages'].to_messages()
    second_message = message_list[1]
    return {'input_query': second_message}
def _strip_code_fence(text):
    """Return ``text`` without a surrounding Markdown code fence, if it has one."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return text
    # Drop the opening fence line, which may carry a language tag (```json).
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body


def run_llm(messages):
    """Send ``messages`` to the chat model and parse its reply as JSON.

    If the model wraps its JSON in a Markdown code fence, the fence is removed
    before parsing.

    Args:
        messages: Prompt value or messages accepted by ``llm.invoke``.

    Returns:
        The parsed JSON value on success. On any failure (request error or
        unparseable reply) the exception object is *returned*, not raised.
        Callers check ``isinstance(result, Exception)``.
    """
    try:
        response = llm.invoke(messages)
        return json.loads(_strip_code_fence(response.content))
    except Exception as e:
        return e
def validate_pandera(tests, df):
    """Run each LLM-generated Pandera rule against its column of ``df``.

    Args:
        tests: Iterable of dicts with ``'column_name'`` and ``'pandera_rule'``
            keys. ``'pandera_rule'`` is a Python expression string that must
            evaluate to a callable taking a one-column DataFrame.
        df: Data to validate.

    Returns:
        pandas.DataFrame with ``Columns`` and ``Result`` (``"✅ Pass"`` or
        ``"❌ Fail - <error>"``) per test. A malformed entry (missing keys,
        not a dict) is reported as a failed row instead of aborting the run.
    """
    validation_results = []
    for test in tests:
        column_name = test.get('column_name') if isinstance(test, dict) else None
        try:
            # SECURITY: eval() executes arbitrary Python generated by the LLM,
            # whose prompt includes table contents. Treat this as untrusted
            # code; a production deployment needs a restricted rule parser or
            # a sandbox instead of eval.
            rule = eval(test['pandera_rule'])
            rule(df[[column_name]])
        except Exception as e:
            result = f"❌ Fail - {str(e)}"
        else:
            result = "✅ Pass"
        validation_results.append({
            "Columns": column_name,
            "Result": result,
        })
    return pd.DataFrame(validation_results)
| #--------------------------------------- | |
# Main Function
def main(table):
    """Profile ``table``, then generate and run LLM validation rules for it.

    Args:
        table: Table name selected in the UI.

    Returns:
        7-tuple for the Gradio outputs: (first 10 rows, overall stats,
        insights, column stats 1, column stats 2, generated rules, validation
        results). If rule generation fails, the rules slot holds a one-row
        error frame and the results slot is an empty DataFrame.
    """
    schema = get_table_schema(table)
    df = get_data_df(schema)
    messages = format_prompt(df=df)
    tests = run_llm(messages)
    print(tests)
    stats = calcualte_stats(df)
    df_insights = stats['insights']
    df_statistics = stats['overall_stats']
    df_stat_1 = stats['stats_1']
    df_stat_2 = stats['stats_2']
    if isinstance(tests, Exception):
        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
        return df.head(10), df_statistics, df_insights, df_stat_1, df_stat_2, tests, pd.DataFrame([])
    tests_df = pd.DataFrame(tests)
    # Rename positionally, but only the columns that exist: indexing
    # columns[0..2] directly raises IndexError when the LLM returns fewer fields.
    tests_df = tests_df.rename(columns=dict(zip(tests_df.columns, ['Column', 'Rule Name', 'Rules'])))
    pandera_results = validate_pandera(tests, df)
    return df.head(10), df_statistics, df_insights, df_stat_1, df_stat_2, tests_df, pandera_results
def user_results(table, text_query):
    """Generate validation rules from a free-text request and run them.

    Args:
        table: Table name selected in the UI.
        text_query: The user's natural-language validation request.

    Returns:
        (generated rules, validation results) as DataFrames. If rule
        generation fails, the first is a one-row error frame and the second
        is an empty DataFrame.
    """
    schema = get_table_schema(table)
    df = get_data_df(schema)
    messages = format_user_prompt(df=df, user_description=text_query)
    tests = run_llm(messages)
    print(f'Generated Tests from user input: {tests}')
    if isinstance(tests, Exception):
        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
        return tests, pd.DataFrame([])
    tests_df = pd.DataFrame(tests)
    # Rename positionally, but only the columns that exist: indexing
    # columns[0..2] directly raises IndexError when the LLM returns fewer fields.
    tests_df = tests_df.rename(columns=dict(zip(tests_df.columns, ['Column', 'Rule Name', 'Rules'])))
    pandera_results = validate_pandera(tests, df)
    return tests_df, pandera_results
# Custom CSS styling
# Only CSS belongs in this string. A stray Python line (print(...)) used to sit
# in front of the first selector, which made the browser discard the
# `.gradio-container` rule.
custom_css = """
.gradio-container {
    background-color: #f0f4f8;
}
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}
.gr-button {
    background-color: #4a90e2 !important;
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}
"""
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Dataset Test Workflow</strong>
<br>
<span style='font-size: 20px;'>Implement and Automate Data Validation Processes.</span>
</div>
""")
    with gr.Row():
        # Left column: schema/table pickers and the auto-validation trigger.
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
            with gr.Row():
                generate_result = gr.Button("Validate Data", variant="primary")
        # Right column: result tabs.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column():
                            data_description = gr.DataFrame(label="Data Description", value=[], interactive=False)
                    with gr.Row():
                        # These two frames show calcualte_stats' stats_1/stats_2,
                        # which are two randomly sampled numeric-or-object
                        # columns, not one categorical and one numerical column.
                        # The labels are therefore generic.
                        with gr.Column():
                            describe_cat = gr.DataFrame(label="Column Statistics (1)", value=[], interactive=False)
                        with gr.Column():
                            describe_num = gr.DataFrame(label="Column Statistics (2)", value=[], interactive=False)
                with gr.Tab("Alerts"):
                    data_alerts = gr.DataFrame(label="Alerts", value=[], interactive=False)
                with gr.Tab("Rules & Validations"):
                    tests_output = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                    test_result_output = gr.DataFrame(label="Validation Result", value=[], interactive=False)
                with gr.Tab("Data"):
                    result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)
                with gr.Tab('Text to Validation'):
                    with gr.Row():
                        query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.")
                    with gr.Row():
                        with gr.Column():
                            pass  # empty spacer column that pushes the button to the right
                        with gr.Column(scale=1, min_width=50):
                            user_generate_result = gr.Button("Validate Data", variant="primary")
                    with gr.Row():
                        with gr.Column():
                            query_tests = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                        with gr.Column():
                            query_result = gr.DataFrame(label="Validation Result", value=[], interactive=False)

    # Event wiring. The output order must match main()'s 7-tuple:
    # (rows, overall stats, insights, stats_1, stats_2, rules, results).
    schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_result.click(main, inputs=[tables_dropdown], outputs=[result_output, data_description, data_alerts, describe_cat, describe_num, tests_output, test_result_output])
    user_generate_result.click(user_results, inputs=[tables_dropdown, query_input], outputs=[query_tests, query_result])

if __name__ == "__main__":
    demo.launch(debug=True)