# NOTE: the three lines below ("Spaces: / Sleeping / Sleeping") were Hugging
# Face Space page residue captured with the source, not Python code.
import os | |
import json | |
import duckdb | |
import gradio as gr | |
import pandas as pd | |
import pandera as pa | |
from pandera import Column | |
import ydata_profiling as pp | |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace | |
from langsmith import traceable | |
from langchain import hub | |
import warnings | |
import dlt | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | |
# Height of the Tabs Text Area (rows shown in the Gradio DataFrame tabs).
TAB_LINES = 8
# ---------- CONNECT TO DATABASE ----------
# MotherDuck service token; None when unset, in which case the connect call
# below fails with an authentication error.
md_token = os.getenv('MD_TOKEN')
# Read-only connection to the MotherDuck-hosted database `my_db`; shared by
# all query helpers in this module.
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
# ---------------------------------------
# ------- LOAD HUGGINGFACE MODEL -------
# Candidate chat models, tried in order until one endpoint is reachable.
models = ["Qwen/Qwen2.5-72B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct",
          "meta-llama/Llama-3.1-70B-Instruct"]
model_loaded = False
endpoint = None
for model in models:
    try:
        endpoint = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
        # Probe the endpoint; raises when the model is not deployed/reachable.
        info = endpoint.client.get_endpoint_info()
        model_loaded = True
        break
    except Exception as e:
        print(f"Error for model {model}: {e}")
        continue
if not model_loaded:
    # BUG FIX: the flag was previously computed but never checked, so when
    # every model failed the ChatHuggingFace call below raised a confusing
    # error on an undefined endpoint. Fail fast with a clear message instead.
    raise RuntimeError(f"None of the configured models could be loaded: {models}")
llm = ChatHuggingFace(llm=endpoint).bind(max_tokens=8192)
# ---------------------------------------
# ----- LOAD PROMPTS FROM LANGCHAIN HUB -----
prompt_autogenerate = hub.pull("autogenerate-rules-testworkflow")
prompt_user_input = hub.pull("usergenerate-rules-testworkflow")
# -------------- ALL UTILS --------------
def get_schemas():
    """Return all schema names in the connected database, minus system schemas."""
    query = """
    SELECT DISTINCT schema_name
    FROM information_schema.schemata
    WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """
    rows = conn.execute(query).fetchall()
    return [row[0] for row in rows]
# Get Tables
def get_tables_names(schema_name):
    """Return the names of all tables in *schema_name*.

    Uses a bound parameter instead of f-string interpolation so a schema name
    containing a quote cannot break (or inject into) the SQL.
    """
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [table[0] for table in tables]
# Update Tables
def update_table_names(schema_name):
    """Refresh the tables dropdown with the tables of the newly selected schema."""
    return gr.update(choices=get_tables_names(schema_name))
# def get_data_df(schema): | |
# print('Getting Dataframe from the Database') | |
# return conn.sql(f"SELECT * FROM {schema} LIMIT 1000") | |
def fetch_data(schema):
    """Stream up to 1000 rows of *schema* as a sequence of DataFrame chunks.

    Generator form lets dlt consume the result incrementally instead of
    materializing the whole result set at once.
    """
    relation = conn.sql(f"SELECT * FROM {schema} LIMIT 1000")
    while True:
        chunk = relation.fetch_df_chunk(2)
        if chunk is None or len(chunk) == 0:
            return
        yield chunk
def create_pipeline(schema):
    """Copy *schema* (a 'database.dataset.table' path) into a local duckdb via dlt.

    Returns the 'dataset.table' name the data was loaded under.
    """
    parts = schema.split('.')
    dataset_name = parts[1]
    print("Dataset Name: ", dataset_name)
    table_name = parts[2]
    print("Table Name: ", table_name)
    pipeline = dlt.pipeline(
        pipeline_name='duckdb_pipeline',
        destination='duckdb',
        dataset_name=dataset_name,
    )
    # "replace" makes the load idempotent: rerunning overwrites the table.
    load_info = pipeline.run(fetch_data(schema),
                             table_name=table_name,
                             write_disposition="replace")
    print(load_info)
    return f"{dataset_name}.{table_name}"
def load_pipeline(table_name):
    """Open the local dlt duckdb file and return (connection, first 1000 rows).

    The caller owns the returned connection and is responsible for closing it.
    """
    local_conn = duckdb.connect("duckdb_pipeline.duckdb")
    preview = local_conn.sql(f"SELECT * FROM {table_name} LIMIT 1000").df()
    return local_conn, preview
def df_summary(df):
    """Summarize each column of *df* into one row of statistics.

    Numeric columns get max/min/count/nunique; object/categorical columns get
    count/nunique/top (most frequent value). Columns of any other dtype
    (e.g. datetime) are skipped, matching the original behavior.

    Returns a DataFrame with columns: column, max, min, count, nunique, dtype, top.
    """
    summary = []
    for column in df.columns:
        series = df[column]
        if pd.api.types.is_numeric_dtype(series):
            summary.append({
                "column": column,
                "max": series.max(),
                "min": series.min(),
                "count": series.count(),
                "nunique": series.nunique(),
                "dtype": str(series.dtype),
                "top": None
            })
        # isinstance check replaces the deprecated pd.api.types.is_categorical_dtype.
        elif isinstance(series.dtype, pd.CategoricalDtype) or pd.api.types.is_object_dtype(series):
            # mode() can be empty for an all-NaN column; guard before iloc[0].
            mode = series.mode()
            top_value = mode.iloc[0] if not mode.empty else None
            summary.append({
                "column": column,
                "max": None,
                "min": None,
                "count": series.count(),
                "nunique": series.nunique(),
                "dtype": str(series.dtype),
                "top": top_value
            })
    summary_df = pd.DataFrame(summary)
    return summary_df.reset_index(drop=True)
def format_prompt(df):
    """Fill the auto-generate prompt with a data sample and per-column summary."""
    stats = df_summary(df)
    sample_json = df.head().to_json(orient='records')
    return prompt_autogenerate.format_prompt(data=sample_json,
                                             summary=stats.to_json(orient='records'))
def format_user_prompt(df, user_description=None):
    """Fill the user-rule prompt with a data sample and the user's rule description.

    BUG FIX: `user_results` calls this with `user_description=...`, which the
    original single-argument signature rejected with a TypeError. The new
    parameter defaults to None, so existing single-argument calls still work.
    """
    kwargs = {"data": df.head().to_json(orient='records')}
    if user_description is not None:
        # NOTE(review): assumes the hub prompt declares a `user_description`
        # input variable — confirm against the LangSmith prompt definition.
        kwargs["user_description"] = user_description
    return prompt_user_input.format_prompt(**kwargs)
def process_inputs(inputs):
    """Pick the second message of the formatted prompt as the input query."""
    message_list = inputs['messages'].to_messages()
    return {'input_query': message_list[1]}
def run_llm(messages):
    """Invoke the chat model and parse its JSON answer.

    Returns the parsed rule list on success, or the raised exception object on
    failure — callers distinguish the two with `isinstance(result, Exception)`.
    """
    try:
        response = llm.invoke(messages)
        # Log the raw answer (backticks swapped for quotes to keep logs tidy).
        print(response.content.replace("```", "'''").replace("json", ""))
        # Strip the markdown code fence before parsing.
        cleaned = response.content.replace("```", "").replace("json", "")
        parsed = json.loads(cleaned)
    except Exception as e:
        return e
    return parsed
# Get Schema
def get_table_schema(table):
    """Resolve *table* to its fully qualified 'database.schema.table' path.

    Uses a bound parameter instead of f-string interpolation so the table name
    cannot break (or inject into) the SQL.
    """
    result = conn.execute(
        "SELECT database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).df()
    parent_database = result.iloc[0, 0]
    schema_name = result.iloc[0, 1]
    # NOTE: the original also rewrote the table's CREATE statement to use the
    # full path, but that value was never returned or used — dropped here.
    return f"{parent_database}.{schema_name}.{table}"
def describe(df):
    """Return (numeric_info, categorical_info) describe() tables for *df*.

    Either result is an empty DataFrame when *df* has no columns of that kind.
    Each table has the column names in a 'column' column.
    """
    numerical_info = pd.DataFrame()
    categorical_info = pd.DataFrame()
    numeric_part = df.select_dtypes(include=['number'])
    if numeric_part.shape[1] >= 1:
        numerical_info = (numeric_part.describe().T.reset_index()
                          .rename(columns={'index': 'column'}))
    object_part = df.select_dtypes(include=['object'])
    if object_part.shape[1] >= 1:
        categorical_info = (object_part.describe().T.reset_index()
                            .rename(columns={'index': 'column'}))
    return numerical_info, categorical_info
def validate_pandera(tests, df):
    """Run each LLM-generated pandera rule against its column of *df*.

    `tests` is a list of dicts with at least 'column_name' and 'pandera_rule'
    (the latter a string of Python source that should evaluate to a callable
    pandera schema/column). Returns a DataFrame with one pass/fail row per rule.
    """
    validation_results = []
    for test in tests:
        column_name = test['column_name']
        try:
            # SECURITY: eval() executes LLM-produced code verbatim — this is
            # effectively untrusted input. Consider pandera's own (de)serialization
            # or a restricted parser instead of raw eval.
            rule = eval(test['pandera_rule'])
            # Applying the rule raises (e.g. SchemaError) when validation fails,
            # which lands us in the except branch below.
            validated_column = rule(df[[column_name]])
            validation_results.append({
                "Columns": column_name,
                "Result": "✅ Pass"
            })
        except Exception as e:
            # Any failure — bad rule source, missing column, or a genuine
            # validation error — is reported as a Fail row, not raised.
            validation_results.append({
                "Columns": column_name,
                "Result": f"❌ Fail - {str(e)}"
            })
    return pd.DataFrame(validation_results)
def statistics(df):
    """Profile *df* with ydata-profiling and return (statistics, alerts) tables.

    The statistics table maps human-readable labels to integer counts; the
    alerts table lists the profiler's data-quality findings with a category.
    """
    report = pp.ProfileReport(df)
    desc = report.get_description()
    table_info, alert_objs = desc.table, desc.alerts
    # Human-readable labels for the profiler's statistic keys.
    label_map = {
        'n': 'Number of observations',
        'n_var': 'Number of variables',
        'n_cells_missing': 'Number of cells missing',
        'n_vars_with_missing': 'Number of columns with missing data',
        'n_vars_all_missing': 'Columns with all missing data',
        'p_cells_missing': 'Missing cells (%)',
        'n_duplicates': 'Duplicated rows',
        'p_duplicates': 'Duplicated rows (%)',
    }
    stats = {label_map.get(key, key): value
             for key, value in table_info.items() if key != 'types'}
    # Flatten the per-dtype column counts into labelled entries.
    type_counts = table_info.get('types', {})
    for type_key, label in (('Text', 'Number of text columns'),
                            ('Categorical', 'Number of categorical columns'),
                            ('Numeric', 'Number of numeric columns'),
                            ('DateTime', 'Number of datetime columns')):
        if type_key in type_counts:
            stats[label] = type_counts[type_key]
    df_statistics = pd.DataFrame(list(stats.items()),
                                 columns=['Statistic Description', 'Value'])
    df_statistics['Value'] = df_statistics['Value'].astype(int)
    # Alerts: strip the bracketed column markers from the rendered message.
    alert_rows = [(str(alert).replace('[', '').replace(']', ''), alert.alert_type_name)
                  for alert in alert_objs]
    df_alerts = pd.DataFrame(alert_rows, columns=['Data Quality Issue', 'Category'])
    return df_statistics, df_alerts
# ---------------------------------------
# Main Function
def main(table):
    """End-to-end validation for *table*: load, profile, generate & run rules.

    Returns the 7 values bound to the Gradio outputs: data preview (10 rows),
    statistics, alerts, categorical describe, numerical describe, the generated
    rules table, and the pandera validation results.
    """
    schema = get_table_schema(table)
    # Create dlt pipeline
    table_name = create_pipeline(schema)
    # Load dlt pipeline
    connection, df = load_pipeline(table_name)
    try:
        df_statistics, df_alerts = statistics(df)
        describe_num, describe_cat = describe(df)
        messages = format_prompt(df=df)
        tests = run_llm(messages)
        if isinstance(tests, Exception):
            # Surface the LLM/parsing error in the rules table instead of crashing.
            error_df = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
            return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, error_df, pd.DataFrame([])
        tests_df = pd.DataFrame(tests)
        tests_df.rename(columns={tests_df.columns[0]: 'Column',
                                 tests_df.columns[1]: 'Rule Name',
                                 tests_df.columns[2]: 'Rules'}, inplace=True)
        pandera_results = validate_pandera(tests, df)
        return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
    finally:
        # BUG FIX: the original leaked the connection when returning early on
        # an LLM error; finally closes it on every path.
        connection.close()
def user_results(table, text_query):
    """Generate and run validation rules described by the user's *text_query*.

    Returns (rules table, pandera validation results) for the Gradio outputs.
    """
    schema = get_table_schema(table)
    # Create dlt pipeline
    table_name = create_pipeline(schema)
    # Load dlt pipeline
    connection, df = load_pipeline(table_name)
    try:
        messages = format_user_prompt(df=df, user_description=text_query)
        # BUG FIX: the original never invoked the LLM here, so `tests` was
        # undefined and the next line raised a NameError.
        tests = run_llm(messages)
        print(f'Generated Tests from user input: {tests}')
        if isinstance(tests, Exception):
            error_df = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
            return error_df, pd.DataFrame([])
        tests_df = pd.DataFrame(tests)
        tests_df.rename(columns={tests_df.columns[0]: 'Column',
                                 tests_df.columns[1]: 'Rule Name',
                                 tests_df.columns[2]: 'Rules'}, inplace=True)
        pandera_results = validate_pandera(tests, df)
        return tests_df, pandera_results
    finally:
        # BUG FIX: close the connection on every path (original leaked it on
        # the error return).
        connection.close()
# Custom CSS styling
# BUG FIX: a stray `print('Validated Tests with Pandera')` line had leaked
# into this stylesheet string; removed since it is not valid CSS.
custom_css = """
.gradio-container {
    background-color: #f0f4f8;
}
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}
.gr-button {
    background-color: #4a90e2 !important;
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}
"""
# ---------------- GRADIO UI ----------------
# Layout: left column picks schema/table, right column shows result tabs;
# a separate tab handles free-text ("Text to Validation") rule generation.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    # Logo banner — expects logo.png next to this script.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
    <div style='text-align: center;'>
    <strong style='font-size: 36px;'>Dataset Test Workflow</strong>
    <br>
    <span style='font-size: 20px;'>Implement and Automate Data Validation Processes.</span>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Schema dropdown is populated at build time from the live connection.
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            # Table choices are filled by update_table_names when a schema is picked.
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
            with gr.Row():
                generate_result = gr.Button("Validate Data", variant="primary")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column():
                            data_description = gr.DataFrame(label="Data Description", value=[], interactive=False)
                    with gr.Row():
                        with gr.Column():
                            describe_cat = gr.DataFrame(label="Categorical Information", value=[], interactive=False)
                        with gr.Column():
                            describe_num = gr.DataFrame(label="Numerical Information", value=[], interactive=False)
                with gr.Tab("Alerts"):
                    data_alerts = gr.DataFrame(label="Alerts", value=[], interactive=False)
                with gr.Tab("Rules & Validations"):
                    tests_output = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                    test_result_output = gr.DataFrame(label="Validation Result", value=[], interactive=False)
                with gr.Tab("Data"):
                    result_output = gr.DataFrame(label="Dataframe (10 Rows)", value=[], interactive=False)
                with gr.Tab('Text to Validation'):
                    with gr.Row():
                        query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.")
                    with gr.Row():
                        with gr.Column():
                            # Empty spacer column so the button sits to the right.
                            pass
                        with gr.Column(scale=1, min_width=50):
                            user_generate_result = gr.Button("Validate Data", variant="primary" )
                    with gr.Row():
                        with gr.Column():
                            query_tests = gr.DataFrame(label="Validation Rules", value=[], interactive=False)
                        with gr.Column():
                            query_result = gr.DataFrame(label="Validation Result", value=[], interactive=False)
    # Event wiring: schema change refreshes tables; buttons run the pipelines.
    schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_result.click(main, inputs=[tables_dropdown], outputs=[result_output, data_description, data_alerts, describe_cat, describe_num, tests_output, test_result_output])
    user_generate_result.click(user_results, inputs=[tables_dropdown, query_input], outputs=[query_tests, query_result])
if __name__ == "__main__":
    demo.launch(debug=True)