"""Gradio app that profiles a MotherDuck table and validates it with pandera.

The workflow for a selected table is:

1. fetch up to 1000 rows from the database,
2. build summary statistics and data-quality alerts with ydata-profiling,
3. ask an LLM (prompted with ``PROMPT_PANDERA``) for pandera rules based on
   the first rows of the data,
4. evaluate each generated rule against the fetched rows and show the
   results in the UI.
"""
import os
import json

import duckdb
import gradio as gr
import pandas as pd
import pandera as pa  # referenced by the LLM-generated rules evaluated below
from pandera import Column  # referenced by the LLM-generated rules evaluated below
import ydata_profiling as pp
from huggingface_hub import InferenceClient

from prompt import PROMPT_PANDERA

# Height of the Tabs Text Area
TAB_LINES = 8

# Load Token
md_token = os.getenv('MD_TOKEN')
# Assigning None to os.environ raises TypeError, so only re-export the
# Hugging Face token when it is actually present in the environment.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token is not None:
    os.environ['HF_TOKEN'] = _hf_token

INPUT_PROMPT = ' Here is the frist few samples of data: {data} '

print('Connecting to DB...')
# Connect to DB
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")


# Get Databases
def get_schemas():
    """Return the names of all user schemas visible on the connection.

    The built-in ``information_schema`` and ``pg_catalog`` schemas are
    excluded.
    """
    schemas = conn.execute("""
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """).fetchall()
    return [item[0] for item in schemas]


# Get Tables
def get_tables_names(schema_name):
    """Return the names of the tables that belong to ``schema_name``."""
    # Bound parameter instead of string interpolation: the value comes from
    # the UI and must never be spliced into the SQL text.
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [table[0] for table in tables]


# Update Tables
def update_table_names(schema_name):
    """Refresh the table dropdown after the selected schema changed.

    The current selection is cleared so that a table picked under the
    previous schema cannot be submitted against the new one.
    """
    tables = get_tables_names(schema_name)
    return gr.update(choices=tables, value=None)


def get_data_df(schema):
    """Load up to 1000 rows of the table at ``schema`` into a DataFrame.

    Args:
        schema: Table path as produced by ``get_table_schema``
            (``database.schema.table``). It is interpolated into the query
            as an identifier, so it must come from a trusted source.
    """
    print('Getting Dataframe from the Database')
    return conn.sql(f"SELECT * FROM {schema} LIMIT 1000").df()


def run_llm(df):
    """Ask the LLM for pandera rules based on the first rows of ``df``.

    Returns:
        The parsed JSON payload produced by the model, or the caught
        exception *instance* (returned, not raised) when the request or the
        JSON parsing fails. ``main`` relies on this by checking
        ``isinstance(tests, Exception)``.
    """
    sample_json = df.head().to_json(orient='records')
    messages = [
        {"role": "system", "content": PROMPT_PANDERA},
        {"role": "user", "content": INPUT_PROMPT.format(data=sample_json)},
    ]

    try:
        response = client.chat_completion(messages, max_tokens=1024)
        content = response.choices[0].message.content
        print(content)
        tests = json.loads(content)
    except Exception as e:
        # Deliberate best-effort: the error is handed back to the caller,
        # which renders it in the UI instead of crashing the request.
        return e

    return tests


# Get Schema
def get_table_schema(table, schema_name=None):
    """Return the fully qualified path ``database.schema.table`` of a table.

    Args:
        table: Name of the table to look up in ``duckdb_tables()``.
        schema_name: Optional schema used to disambiguate tables that share
            a name across schemas. When omitted, the first match is used.

    Raises:
        ValueError: If no matching table exists.
    """
    query = (
        "SELECT database_name, schema_name FROM duckdb_tables() "
        "WHERE table_name = ?"
    )
    params = [table]
    if schema_name:
        query += " AND schema_name = ?"
        params.append(schema_name)

    result = conn.execute(query, params).df()
    if result.empty:
        scope = f" in schema {schema_name!r}" if schema_name else ""
        raise ValueError(f"Table {table!r} was not found{scope}.")

    parent_database = result.iloc[0, 0]
    table_schema = result.iloc[0, 1]
    return f"{parent_database}.{table_schema}.{table}"


def _describe_subset(df, dtypes):
    """Describe the columns of ``df`` matching ``dtypes``, one row per column.

    ``DataFrame.describe`` raises ``ValueError`` on a frame without columns,
    so an empty result is returned when no column matches.
    """
    subset = df.select_dtypes(include=dtypes)
    if subset.shape[1] == 0:
        return pd.DataFrame(columns=['column'])
    return subset.describe().T.reset_index().rename(columns={'index': 'column'})


def describe(df):
    """Return ``(numerical_info, categorical_info)`` summaries of ``df``.

    Each summary is the transposed ``describe()`` output with the source
    column name exposed in a ``column`` column. Note the order of the
    returned pair: numerical first, categorical second.
    """
    numerical_info = _describe_subset(df, ['number'])
    categorical_info = _describe_subset(df, ['object'])
    return numerical_info, categorical_info


def validate_pandera(tests, df):
    """Evaluate every generated pandera rule against ``df``.

    Args:
        tests: Iterable of dicts with the keys ``column_name`` and
            ``pandera_rule`` (a Python expression that evaluates to a
            callable validator).
        df: Data to validate.

    Returns:
        DataFrame with one row per rule and the columns ``Columns`` and
        ``Result``.
    """
    validation_results = []

    for test in tests:
        column_name = test['column_name']
        try:
            # SECURITY: this evaluates model-generated text as Python code.
            # The rule string is not sanitized; run this app only in an
            # environment where arbitrary code execution is acceptable.
            rule = eval(test['pandera_rule'])
            rule(df[[column_name]])
            outcome = "✅ Pass"
        except Exception as e:
            # Any failure (invalid rule, missing column, failed check) is
            # reported as a failed validation instead of aborting the run.
            outcome = f"❌ Fail - {str(e)}"
        validation_results.append({"Columns": column_name, "Result": outcome})

    return pd.DataFrame(validation_results)


def statistics(df):
    """Profile ``df`` and return ``(df_statistics, df_alerts)``.

    ``df_statistics`` lists the table-level figures from the profiling
    report under human readable labels; ``df_alerts`` lists the reported
    data-quality alerts together with their category.
    """
    profile = pp.ProfileReport(df)
    report_dict = profile.get_description()
    description, alerts = report_dict.table, report_dict.alerts

    # Statistics
    mapping = {
        'n': 'Number of observations',
        'n_var': 'Number of variables',
        'n_cells_missing': 'Number of cells missing',
        'n_vars_with_missing': 'Number of columns with missing data',
        'n_vars_all_missing': 'Columns with all missing data',
        'p_cells_missing': 'Missing cells (%)',
        'n_duplicates': 'Duplicated rows',
        'p_duplicates': 'Duplicated rows (%)',
    }

    # Keys without a friendly label keep their original name; the nested
    # 'types' entry is flattened separately below.
    updated_data = {mapping.get(k, k): v for k, v in description.items() if k != 'types'}

    # Add flattened types information
    type_labels = (
        ('Text', 'Number of text columns'),
        ('Categorical', 'Number of categorical columns'),
        ('Numeric', 'Number of numeric columns'),
        ('DateTime', 'Number of datetime columns'),
    )
    type_counts = description.get('types', {})
    for type_name, label in type_labels:
        if type_name in type_counts:
            updated_data[label] = type_counts[type_name]

    df_statistics = pd.DataFrame(list(updated_data.items()), columns=['Statistic Description', 'Value'])
    # NOTE(review): astype(int) truncates fractional values. If the p_*
    # entries are fractions in [0, 1] they will be displayed as 0 — confirm
    # against the profiling report and scale/round instead if so.
    df_statistics['Value'] = df_statistics['Value'].astype(int)

    # Alerts
    alerts_list = [
        (str(alert).replace('[', '').replace(']', ''), alert.alert_type_name)
        for alert in alerts
    ]
    df_alerts = pd.DataFrame(alerts_list, columns=['Data Quality Issue', 'Category'])

    return df_statistics, df_alerts


# Main Function
def main(table, schema_name=None):
    """Run the full validation workflow for ``table``.

    Args:
        table: Name of the table selected in the UI.
        schema_name: Optional schema selected in the UI, used to resolve
            the table unambiguously.

    Returns:
        A 7-tuple matching the outputs wired to the "Validate Data" button:
        data preview (10 rows), statistics, alerts, categorical summary,
        numerical summary, generated rules, validation results.

    Raises:
        gr.Error: If no table is selected or the table cannot be found.
    """
    if not table:
        raise gr.Error("Please select a table before validating.")

    try:
        table_path = get_table_schema(table, schema_name)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    df = get_data_df(table_path)
    df_statistics, df_alerts = statistics(df)
    # describe() returns the numerical summary first, then the categorical.
    describe_num, describe_cat = describe(df)

    tests = run_llm(df)
    print(tests)

    if isinstance(tests, Exception):
        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
        return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])

    tests_df = pd.DataFrame(tests)
    # Rename by position; zip() stops early if the model returned fewer than
    # three fields, instead of failing with an IndexError.
    display_names = ('Column', 'Rule Name', 'Rules')
    tests_df = tests_df.rename(columns=dict(zip(tests_df.columns, display_names)))
    pandera_results = validate_pandera(tests, df)

    return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results


# Custom CSS styling
custom_css = (
    " .gradio-container { background-color: #f0f4f8; }"
    " .logo { max-width: 200px; margin: 20px auto; display: block; }"
    " .gr-button { background-color: #4a90e2 !important; }"
    " .gr-button:hover { background-color: #3a7bc8 !important; }"
    " "
)

with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)

    gr.Markdown("""
Dataset Test Workflow
Implement and Automate Data Validation Processes.
""")

    with gr.Row():

        # Left column: schema/table selection and the action button.
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)

            with gr.Row():
                generate_query_button = gr.Button("Validate Data", variant="primary")

        # Right column: one tab per result category.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column():
                            data_description = gr.DataFrame(label="Data Description", value=[], interactive=False)
                    with gr.Row():
                        with gr.Column():
                            describe_cat = gr.DataFrame(label="Categorical Information", value=[], interactive=False)
                        with gr.Column():
                            describe_num = gr.DataFrame(label="Numerical Information", value=[], interactive=False)
                with gr.Tab("Alerts"):
                    data_alerts = gr.DataFrame(label="Alerts", value=[], interactive=False)
                with gr.Tab("Rules & Validations"):
                    tests_output = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                    test_result_output = gr.DataFrame(label="Validation Result", value=[], interactive=False)
                with gr.Tab("Data"):
                    result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)

    schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
    # The selected schema is passed along so the table lookup is unambiguous.
    generate_query_button.click(
        main,
        inputs=[tables_dropdown, schema_dropdown],
        outputs=[
            result_output,
            data_description,
            data_alerts,
            describe_cat,
            describe_num,
            tests_output,
            test_result_output,
        ],
    )

if __name__ == "__main__":
    demo.launch(debug=True)