Spaces:
Sleeping
Sleeping
import os | |
import json | |
import duckdb | |
import gradio as gr | |
import pandas as pd | |
import pandera as pa | |
from pandera import Column | |
import ydata_profiling as pp | |
from huggingface_hub import InferenceClient | |
from prompt import PROMPT_PANDERA | |
# Height (in lines) of the tabs' text areas.
TAB_LINES = 8

# Load tokens from the environment.
md_token = os.getenv('MD_TOKEN')
# HF_TOKEN is read from the process environment by the Hugging Face client.
# Re-assigning os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN') changes nothing
# when the variable is set and raises TypeError (os.environ only accepts str)
# when it is not, so only warn about a missing token instead of crashing.
if os.getenv('HF_TOKEN') is None:
    print('Warning: HF_TOKEN is not set.')

# User-message template for the LLM; `{data}` receives the sample rows as JSON.
INPUT_PROMPT = '''
Here are the first few samples of data:
<Sample Data>
{data}
</Sample Data>
'''
print('Connecting to DB...')
# Read-only database connection; the MotherDuck token is embedded in the
# connection string.
conn = duckdb.connect(
    database=f"md:my_db?motherduck_token={md_token}",
    read_only=True,
)
# Inference client bound to the chat model used to generate validation rules.
client = InferenceClient(model="meta-llama/Meta-Llama-3-70B-Instruct")
def get_schemas():
    """Return the distinct schema names visible through ``conn``.

    The ``information_schema`` and ``pg_catalog`` schemas are excluded by
    the query itself.
    """
    query = """
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
"""
    rows = conn.execute(query).fetchall()
    # Each row is a 1-tuple because the query selects a single column.
    return [schema_name for (schema_name,) in rows]
def get_tables_names(schema_name):
    """Return the names of all tables that belong to ``schema_name``.

    Args:
        schema_name: Schema to look up in ``information_schema.tables``.

    Returns:
        A list of table names (empty when the schema has no tables).
    """
    # Bind the schema name as a parameter instead of interpolating it into
    # the SQL text, so quotes in the value cannot alter the statement.
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
def update_table_names(schema_name):
    """Build a Gradio update whose choices are the tables of ``schema_name``."""
    return gr.update(choices=get_tables_names(schema_name))
def get_data_df(schema):
    """Fetch up to 1000 rows of the relation named by ``schema`` as a DataFrame.

    ``schema`` is interpolated directly into the FROM clause, so it must be a
    trusted table path (``main`` passes the value from ``get_table_schema``).
    """
    print('Getting Dataframe from the Database')
    query = f"SELECT * FROM {schema} LIMIT 1000"
    relation = conn.sql(query)
    return relation.df()
def run_llm(df):
    """Ask the chat model for validation rules based on the first rows of ``df``.

    Returns:
        The model's reply parsed as JSON, or — if the request or the JSON
        parsing fails — the caught exception object (returned, not raised).
    """
    system_message = {"role": "system", "content": PROMPT_PANDERA}
    sample_json = df.head().to_json(orient='records')
    user_message = {"role": "user", "content": INPUT_PROMPT.format(data=sample_json)}
    conversation = [system_message, user_message]
    try:
        completion = client.chat_completion(conversation, max_tokens=1024)
        raw_answer = completion.choices[0].message.content
        print(raw_answer)
        parsed = json.loads(raw_answer)
    except Exception as error:
        # Callers detect failure with isinstance(result, Exception).
        return error
    return parsed
def get_table_schema(table):
    """Return the fully qualified ``database.schema.table`` path of ``table``.

    The lookup is by table name only; when several tables share the name,
    the first row returned by ``duckdb_tables()`` is used.

    Args:
        table: Unqualified table name.

    Returns:
        The string ``"<database_name>.<schema_name>.<table>"``.

    Raises:
        ValueError: If no table with that name exists.
    """
    # Bind the table name as a parameter rather than formatting it into the SQL.
    row = conn.execute(
        "SELECT database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} was not found in the connected database.")
    parent_database, schema_name = row
    return f"{parent_database}.{schema_name}.{table}"
def describe(df):
    """Summarise the numeric and the object-typed columns of ``df``.

    Args:
        df: DataFrame to summarise.

    Returns:
        A ``(numerical_info, categorical_info)`` tuple. Each frame has one row
        per source column, with the source column's name in ``'column'`` and
        the ``DataFrame.describe()`` statistics in the remaining columns. When
        ``df`` has no columns of a kind, the corresponding frame is empty and
        contains only the ``'column'`` header.
    """

    def _summarize(dtypes):
        """Describe the columns of ``df`` matching ``dtypes``, one row each."""
        subset = df.select_dtypes(include=dtypes)
        # DataFrame.describe() raises ValueError on a frame without columns.
        if subset.shape[1] == 0:
            return pd.DataFrame(columns=['column'])
        return subset.describe().T.reset_index().rename(columns={'index': 'column'})

    return _summarize(['number']), _summarize(['object'])
def validate_pandera(tests, df):
    """Run each generated rule against its column and report pass/fail.

    Args:
        tests: Iterable of dicts with a ``'column_name'`` key and a
            ``'pandera_rule'`` key holding a Python expression (presumably a
            pandera ``Column(...)`` definition produced by the LLM).
        df: DataFrame whose columns are validated.

    Returns:
        A DataFrame with one row per rule and the columns ``'Columns'`` and
        ``'Result'`` (``'✅ Pass'`` or ``'❌ Fail - <reason>'``).
    """
    validation_results = []
    # Validate each column separately so one failure does not hide the others.
    for test in tests:
        column_name = test.get('column_name')
        try:
            # SECURITY: eval() executes arbitrary Python coming from the LLM
            # response; only use this with a trusted model and prompt.
            # It lives inside the try block so that a malformed rule is
            # reported as a failure for this column instead of aborting the run.
            rule = eval(test['pandera_rule'])
            # Calling the rule validates the single-column frame.
            rule(df[[column_name]])
        except Exception as e:
            result = f"❌ Fail - {str(e)}"
        else:
            result = "✅ Pass"
        validation_results.append({
            "Columns": column_name,
            "Result": result,
        })
    return pd.DataFrame(validation_results)
def statistics(df):
    """Profile ``df`` and return its table statistics and data-quality alerts.

    Args:
        df: DataFrame to profile with ``ydata_profiling``.

    Returns:
        A ``(df_statistics, df_alerts)`` tuple. ``df_statistics`` has the
        columns ``'Statistic Description'`` and ``'Value'``; ``df_alerts`` has
        the columns ``'Data Quality Issue'`` and ``'Category'``.
    """
    profile = pp.ProfileReport(df)
    report = profile.get_description()
    description, alerts = report.table, report.alerts

    # Human-readable labels for the table statistics; unknown keys keep their
    # raw name.
    labels = {
        'n': 'Number of observations',
        'n_var': 'Number of variables',
        'n_cells_missing': 'Number of cells missing',
        'n_vars_with_missing': 'Number of columns with missing data',
        'n_vars_all_missing': 'Columns with all missing data',
        'p_cells_missing': 'Missing cells (%)',
        'n_duplicates': 'Duplicated rows',
        'p_duplicates': 'Duplicated rows (%)',
    }
    # Column-type counts to flatten out of description['types'], in display order.
    type_labels = (
        ('Text', 'Number of text columns'),
        ('Categorical', 'Number of categorical columns'),
        ('Numeric', 'Number of numeric columns'),
        ('DateTime', 'Number of datetime columns'),
    )

    stats = {}
    for key, value in description.items():
        if key == 'types':
            continue
        if key.startswith('p_'):
            # The p_* statistics are fractions (see ydata-profiling table
            # stats); casting them to int would truncate them to 0, so show
            # them as percentages to match their "(%)" labels.
            stats[labels.get(key, key)] = round(float(value) * 100, 2)
        else:
            stats[labels.get(key, key)] = int(value)

    column_types = description.get('types', {})
    for type_name, label in type_labels:
        if type_name in column_types:
            stats[label] = int(column_types[type_name])

    df_statistics = pd.DataFrame({
        'Statistic Description': list(stats.keys()),
        # Object dtype keeps counts as ints next to the float percentages.
        'Value': pd.Series(list(stats.values()), dtype=object),
    })

    # Alerts: strip the square brackets from each alert's text.
    alerts_list = [
        (str(alert).replace('[', '').replace(']', ''), alert.alert_type_name)
        for alert in alerts
    ]
    df_alerts = pd.DataFrame(alerts_list, columns=['Data Quality Issue', 'Category'])
    return df_statistics, df_alerts
def main(table):
    """Run the whole validation workflow for ``table``.

    Args:
        table: Name of the table selected in the UI.

    Returns:
        A 7-tuple matching the UI outputs, in order: first 10 rows, table
        statistics, alerts, categorical summary, numerical summary, generated
        rules (or a one-row error frame), and the validation results (empty
        when rule generation failed).

    Raises:
        gr.Error: If no table has been selected.
    """
    if not table:
        raise gr.Error("Please select a table before running the validation.")

    schema = get_table_schema(table)
    df = get_data_df(schema)
    df_statistics, df_alerts = statistics(df)
    # describe() returns (numerical, categorical); unpack in that order so
    # each summary ends up in the UI table with the matching label.
    describe_num, describe_cat = describe(df)
    tests = run_llm(df)
    print(tests)

    # run_llm reports failure by returning the exception instead of raising it.
    if isinstance(tests, Exception):
        error_df = pd.DataFrame([{"error": f"❌ Unable to generate validation rules from the sample data. {tests}"}])
        return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, error_df, pd.DataFrame([])

    tests_df = pd.DataFrame(tests)
    # Rename the leading columns by position; zip() stops early, so a reply
    # with fewer than three keys no longer raises IndexError.
    tests_df.rename(columns=dict(zip(tests_df.columns, ('Column', 'Rule Name', 'Rules'))), inplace=True)
    pandera_results = validate_pandera(tests, df)
    return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
# Custom CSS styling injected into the Gradio page.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""


def _readonly_table(label):
    """Create an empty, non-interactive DataFrame component titled ``label``."""
    return gr.DataFrame(label=label, value=[], interactive=False)


with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"),
    css=custom_css,
) as demo:
    # Header: logo and title banner.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Dataset Test Workflow</strong>
<br>
<span style='font-size: 20px;'>Implement and Automate Data Validation Processes.</span>
</div>
""")

    with gr.Row():
        # Left column: schema/table selection and the run button.
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
            with gr.Row():
                generate_query_button = gr.Button("Validate Data", variant="primary")

        # Right column: tabbed result tables.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column(min_width=220):
                            data_description = _readonly_table("Data Description")
                    with gr.Row():
                        with gr.Column(min_width=320):
                            describe_cat = _readonly_table("Categorical Information")
                        with gr.Column(min_width=320):
                            describe_num = _readonly_table("Numerical Information")
                with gr.Tab("Alerts"):
                    data_alerts = _readonly_table("Alerts")
                with gr.Tab("Rules & Validations"):
                    tests_output = _readonly_table("Validation Rules")
                    test_result_output = _readonly_table("Validation Result")
                with gr.Tab("Data"):
                    result_output = _readonly_table("Dataframe (10 Rows)")

    # Refresh the table list whenever another schema is chosen.
    schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)

    # Order must match the tuple returned by main().
    validation_outputs = [
        result_output,
        data_description,
        data_alerts,
        describe_cat,
        describe_num,
        tests_output,
        test_result_output,
    ]
    generate_query_button.click(main, inputs=[tables_dropdown], outputs=validation_outputs)

if __name__ == "__main__":
    demo.launch(debug=True)