|
import os |
|
import sys |
|
import re |
|
import gradio as gr |
|
import json |
|
import tempfile |
|
import base64 |
|
import io |
|
from typing import List, Dict, Any, Optional, Tuple, Union |
|
import logging |
|
import pandas as pd |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
from plotly.subplots import make_subplots |
|
from flask import Flask, request, jsonify |
|
import uuid |
|
|
|
|
|
# Configure root logging at INFO so the module's logger.info(...) calls are emitted.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Flask application that hosts the JSON API routes registered later in this module.
flask_app = Flask(import_name=__name__)
|
|
|
|
|
import os |
|
from dotenv import load_dotenv |
|
from sqlalchemy import create_engine, text |
|
from sqlalchemy.exc import SQLAlchemyError |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_community.agent_toolkits import create_sql_agent |
|
from langchain_community.utilities import SQLDatabase |
|
from langgraph.prebuilt import create_react_agent |
|
|
|
|
|
# Populate os.environ from a .env file; variables already present in the
# environment take precedence (override=False).
load_dotenv(override=False)
|
|
|
def initialize_llm():
    """Build the Google Gemini chat model used by the SQL agent.

    The API key is read from the ``GOOGLE_API_KEY`` environment variable.

    Returns:
        A ``(llm, error)`` tuple: ``(ChatGoogleGenerativeAI, None)`` on success,
        or ``(None, message)`` when the key is missing/empty or the model
        cannot be constructed.
    """
    try:
        key = os.getenv('GOOGLE_API_KEY')
        if key:
            model = ChatGoogleGenerativeAI(
                model="gemini-2.0-flash",
                google_api_key=key,
                temperature=0.1,
                convert_system_message_to_human=True,
            )
            return model, None
        return None, "No se encontró GOOGLE_API_KEY en las variables de entorno"
    except Exception as exc:
        return None, str(exc)
|
|
|
def setup_database_connection():
    """Build the MySQL connection URL from the environment and verify it works.

    Reads ``DB_USER``, ``DB_PASSWORD`` and ``DB_NAME`` (all required, non-empty)
    and ``DB_HOST`` (default ``"localhost"``), then runs ``SELECT 1`` through a
    short-lived engine to check connectivity.

    Returns:
        ``(connection_string, None)`` when the probe query succeeds, otherwise
        ``(None, message)`` describing the missing variables or the error.
    """
    from urllib.parse import quote

    try:
        db_user = os.getenv('DB_USER')
        db_password = os.getenv('DB_PASSWORD')
        db_host = os.getenv('DB_HOST', 'localhost')
        db_name = os.getenv('DB_NAME')

        if not all([db_user, db_password, db_name]):
            return None, "Faltan variables de entorno para la conexión a la base de datos"

        # Percent-encode the credentials: characters such as '@', ':', '/' or '%'
        # in a password would otherwise be parsed as URL delimiters. quote() (not
        # quote_plus) is used so a space becomes %20, which URL parsing decodes back.
        user = quote(db_user, safe="")
        password = quote(db_password, safe="")
        connection_string = f"mysql+pymysql://{user}:{password}@{db_host}/{db_name}"

        engine = create_engine(connection_string)
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        finally:
            # This engine is only a connectivity probe; release its pool.
            engine.dispose()

        return connection_string, None
    except Exception as e:
        return None, str(e)
|
|
|
def create_agent(llm, connection_string):
    """Build a LangChain SQL agent over the database at ``connection_string``.

    Args:
        llm: Chat model (as returned by ``initialize_llm``).
        connection_string: SQLAlchemy database URL (as returned by
            ``setup_database_connection``).

    Returns:
        ``(agent_executor, None)`` on success, or ``(None, message)`` if either
        argument is missing/falsy or construction fails.
    """
    try:
        if not llm or not connection_string:
            return None, "LLM o conexión a base de datos no proporcionados"

        db = SQLDatabase.from_uri(connection_string)

        agent = create_sql_agent(
            llm=llm,
            db=db,
            agent_type="zero-shot-react-description",
            verbose=True,
            # return_intermediate_steps is an AgentExecutor setting. Extra bare
            # kwargs of create_sql_agent are forwarded to the agent, not the
            # executor, so it must go through agent_executor_kwargs for the
            # executed tool steps to appear in the invoke() result.
            agent_executor_kwargs={"return_intermediate_steps": True},
        )

        return agent, None
    except Exception as e:
        return None, str(e)
|
|
|
# Name under which LangChain's SQL toolkit registers its "execute SQL query" tool.
_SQL_QUERY_TOOL_NAME = "sql_db_query"


def _extract_response_text(response: Any) -> str:
    """Return the agent's final answer text from an ``invoke`` result."""
    if hasattr(response, 'output'):
        return response.output
    if isinstance(response, dict) and 'output' in response:
        return response['output']
    return str(response)


def _extract_select_queries(response: Any) -> List[str]:
    """Return the SELECT statements the agent sent to the SQL query tool, in order.

    ``intermediate_steps`` is read from a dict result (or an attribute) as a
    sequence of ``(action, observation)`` pairs. Only actions whose ``tool`` is
    the SQL query tool contribute; their ``tool_input`` (a string, or a dict with
    a ``"query"`` key) is kept when it starts with ``SELECT``.
    """
    if isinstance(response, dict):
        steps = response.get('intermediate_steps')
    else:
        steps = getattr(response, 'intermediate_steps', None)

    queries: List[str] = []
    for step in steps or []:
        if not isinstance(step, (list, tuple)) or len(step) < 2:
            continue
        action = step[0]
        if getattr(action, 'tool', None) != _SQL_QUERY_TOOL_NAME:
            continue
        tool_input = getattr(action, 'tool_input', '')
        if isinstance(tool_input, dict):
            tool_input = tool_input.get('query', '')
        query = str(tool_input).strip()
        # Only re-run statements that *start* with SELECT: the agent already
        # executed each query once, and repeating a write would apply it twice.
        if query.upper().startswith('SELECT'):
            queries.append(query)
    return queries


def _build_chart(queries: List[str], connection_string: str) -> Optional[go.Figure]:
    """Bar-plot the most recent query whose result has at least two columns.

    The first result column is the x axis and the second the y axis. Charting is
    best-effort: failures are logged and ``None`` is returned.
    """
    if not queries:
        return None
    try:
        engine = create_engine(connection_string)
    except Exception:
        logger.warning("Could not create engine for chart generation", exc_info=True)
        return None
    try:
        # Newest first: only the last chartable query is plotted, and older
        # queries are not re-run once one succeeds.
        for query in reversed(queries):
            try:
                df = pd.read_sql_query(query, engine)
                if len(df.columns) >= 2:
                    return px.bar(df, x=df.columns[0], y=df.columns[1])
            except Exception:
                logger.warning("Could not build chart for query: %s", query, exc_info=True)
    finally:
        engine.dispose()
    return None


def stream_agent_response(question: str, chat_history: Optional[List[List[str]]] = None) -> Tuple[str, Optional[go.Figure]]:
    """Answer ``question`` with the SQL agent and build a chart when possible.

    A fresh LLM, database connection check and agent are created on every call.
    After the agent answers, the last SELECT it executed is re-run and, if its
    result has two or more columns, plotted as a bar chart.

    Args:
        question: Natural-language question for the agent.
        chat_history: Previous ``[question, answer]`` pairs. Accepted for
            interface compatibility; not currently forwarded to the agent.

    Returns:
        ``(response_text, figure_or_None)``. Setup and agent errors are returned
        as a Spanish error message with ``None`` for the figure.
    """
    try:
        llm, llm_error = initialize_llm()
        if llm_error:
            return f"Error al inicializar LLM: {llm_error}", None

        connection_string, db_error = setup_database_connection()
        if db_error:
            return f"Error de conexión a base de datos: {db_error}", None

        agent, agent_error = create_agent(llm, connection_string)
        if agent_error:
            return f"Error al crear el agente: {agent_error}", None

        response = agent.invoke({"input": question})

        response_text = _extract_response_text(response)
        chart_fig = _build_chart(_extract_select_queries(response), connection_string)

        return response_text, chart_fig

    except Exception as e:
        return f"Error al procesar la solicitud: {str(e)}", None
|
|
|
def create_ui():
    """Lay out the Gradio interface without attaching any event handlers.

    Layout: a title and subtitle, then a row holding a wider chat column
    (chat history, question textbox and send button) and a narrower column
    with a plot area, plus a hidden HTML component.

    Returns:
        ``(demo, chatbot, chart_display, question_input, submit_button,
        streaming_output_display)``.
    """
    with gr.Blocks(title="🤖 Asistente SQL con Gemini", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🤖 Asistente de Base de Datos SQL")
        gr.Markdown("Pregunta cualquier cosa sobre tu base de datos en lenguaje natural")

        with gr.Row():
            # Left, wider column: conversation plus the input row.
            with gr.Column(scale=2):
                history_view = gr.Chatbot(label="Chat", type="messages", height=400)

                with gr.Row():
                    prompt_box = gr.Textbox(
                        label="Tu pregunta",
                        placeholder="Ej: ¿Cuántos usuarios hay registrados?",
                        lines=2,
                        scale=4,
                    )
                    send_btn = gr.Button("📤 Enviar", scale=1)

            # Right, narrower column: data visualization.
            with gr.Column(scale=1):
                plot_view = gr.Plot(label="Visualización de datos", height=400)

        # Invisible placeholder component; returned to the caller but unused here.
        hidden_stream = gr.HTML(visible=False)

    return app, history_view, plot_view, prompt_box, send_btn, hidden_stream
|
|
|
|
|
# Process-local, in-memory map of message_id -> user message text, shared by the
# /user_message and /ask endpoints. An entry is removed only after a successful
# /ask, so ids that are never asked about stay until the process exits.
message_store: Dict[str, str] = {}


@flask_app.route('/user_message', methods=['POST'])
def handle_user_message():
    """Store a user message and return the id to pass to ``POST /ask``.

    Expects a JSON object body ``{"message": "<text>"}``.

    Returns:
        ``200`` with ``{"message_id": <uuid4 string>, "status": "success"}``;
        ``400`` if the body is not a JSON object with a non-blank string
        ``message``; ``500`` on unexpected errors.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None (-> 400)
        # instead of raising inside get_json() and being reported as a 500.
        data = request.get_json(silent=True)
        message = data.get('message') if isinstance(data, dict) else None

        # Blank input is rejected, matching the Gradio UI, which ignores it.
        if not isinstance(message, str) or not message.strip():
            return jsonify({'error': 'Se requiere el campo message'}), 400

        message_id = str(uuid.uuid4())
        message_store[message_id] = message

        return jsonify({
            'message_id': message_id,
            'status': 'success'
        })

    except Exception as e:
        logger.exception("Error handling /user_message")
        return jsonify({'error': str(e)}), 500
|
|
|
@flask_app.route('/ask', methods=['POST'])
def handle_ask():
    """Run the SQL agent on a message stored earlier via ``POST /user_message``.

    Expects a JSON object body ``{"message_id": "<id>"}``. The stored message is
    removed only after a successful answer, so a failed request can be retried
    with the same id.

    Returns:
        ``200`` with ``{"response": <agent answer>, "status": "success"}``;
        ``400`` if the body is not a JSON object with a string ``message_id``;
        ``404`` if the id is not in ``message_store``;
        ``500`` if the LLM, database or agent cannot be set up, or on any other error.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None (-> 400)
        # instead of raising inside get_json() and being reported as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'message_id' not in data:
            return jsonify({'error': 'Se requiere el campo message_id'}), 400

        message_id = data['message_id']
        # Ids are issued as uuid4 strings; anything else (e.g. a JSON list, which
        # is unhashable and would raise on the dict lookup) is a client error.
        if not isinstance(message_id, str):
            return jsonify({'error': 'Se requiere el campo message_id'}), 400

        if message_id not in message_store:
            return jsonify({'error': 'ID de mensaje no encontrado'}), 404

        user_message = message_store[message_id]

        llm, llm_error = initialize_llm()
        if llm_error:
            return jsonify({'error': f'Error al inicializar LLM: {llm_error}'}), 500

        connection_string, db_error = setup_database_connection()
        if db_error:
            return jsonify({'error': f'Error de conexión a la base de datos: {db_error}'}), 500

        agent, agent_error = create_agent(llm, connection_string)
        if agent_error:
            return jsonify({'error': f'Error al crear el agente: {agent_error}'}), 500

        response = agent.invoke({"input": user_message})

        # Accept the result as an object with .output, a plain string, or a mapping.
        if hasattr(response, 'output') and response.output:
            response_text = response.output
        elif isinstance(response, str):
            response_text = response
        elif hasattr(response, 'get') and callable(response.get) and 'output' in response:
            response_text = response['output']
        else:
            response_text = str(response)

        # pop() rather than del: a concurrent request may already have consumed
        # this id, and that must not turn a successful answer into a 500.
        message_store.pop(message_id, None)

        return jsonify({
            'response': response_text,
            'status': 'success'
        })

    except Exception as e:
        logger.exception("Error handling /ask")
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
|
def create_application():
    """Create and configure the Gradio application.

    Builds the layout with ``create_ui()`` and wires both the textbox *submit*
    and the send-button *click* to the same two-step chain: ``user_message``
    (append the question to the chat and clear the textbox, unqueued) followed
    by ``bot_response`` (answer it and refresh the chart). The submit chain is
    exposed under the API name ``"ask"``.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    demo, chatbot, chart_display, question_input, submit_button, _streaming_output_display = create_ui()

    if os.getenv('SPACE_ID'):
        # gr.mount_gradio_app() mounts Blocks onto a *FastAPI* app and returns that
        # app; it cannot host the Flask ``flask_app`` and would also replace
        # ``demo`` with a non-Blocks object. The Flask endpoints are therefore only
        # served as a separate process (RUN_FLASK=true).
        logger.info(
            "SPACE_ID set: Flask endpoints are not mounted; run with RUN_FLASK=true to serve them."
        )

    def user_message(user_input: str, chat_history: Optional[List[Dict[str, str]]]) -> Tuple[str, Optional[List[Dict[str, str]]]]:
        """Append the user's question to the chat (messages format) and clear the textbox.

        Empty or whitespace-only input leaves the history untouched.
        """
        if not user_input or not user_input.strip():
            return "", chat_history

        logger.info("User message: %s", user_input)

        if chat_history is None:
            chat_history = []

        chat_history.append({"role": "user", "content": user_input})
        return "", chat_history

    def bot_response(chat_history: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[go.Figure]]:
        """Answer the latest user message; return the updated history and an optional chart.

        Does nothing unless the last message is a non-empty user turn. Earlier
        consecutive (user, assistant) turns are collected as ``[question, answer]``
        pairs for ``stream_agent_response``. Errors are appended as an assistant
        message instead of being raised.
        """
        if not chat_history:
            return chat_history, None

        last = chat_history[-1]
        if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
            return chat_history, None

        try:
            question = last["content"]
            logger.info("Processing question: %s", question)

            # Pair each user turn with the assistant turn right after it; the
            # pending question at the end is a user turn and is never paired.
            pair_history: List[List[str]] = []
            i = 0
            while i < len(chat_history) - 1:
                m1, m2 = chat_history[i], chat_history[i + 1]
                if (
                    isinstance(m1, dict)
                    and m1.get("role") == "user"
                    and isinstance(m2, dict)
                    and m2.get("role") == "assistant"
                ):
                    pair_history.append([m1.get("content", ""), m2.get("content", "")])
                    i += 2
                else:
                    i += 1

            assistant_message, chart_fig = stream_agent_response(question, pair_history)

            chat_history.append({"role": "assistant", "content": assistant_message})

            logger.info("Response generation complete")
            return chat_history, chart_fig

        except Exception as e:
            error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
            logger.exception(error_msg)

            chat_history.append({"role": "assistant", "content": error_msg})
            return chat_history, None

    with demo:
        # Enter in the textbox: echo immediately (unqueued), then answer.
        question_input.submit(
            fn=user_message,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
            queue=False,
        ).then(
            fn=bot_response,
            inputs=[chatbot],
            outputs=[chatbot, chart_display],
            api_name="ask",
        )

        # Send button: same chain, without a named API endpoint.
        submit_button.click(
            fn=user_message,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
            queue=False,
        ).then(
            fn=bot_response,
            inputs=[chatbot],
            outputs=[chatbot, chart_display],
        )

    return demo
|
|
|
|
|
demo = create_application()


def get_app():
    """Return the module-level Gradio app (entry point for Hugging Face Spaces).

    When the ``SPACE_ID`` environment variable is set, the app's ``title`` and
    ``description`` attributes are first overwritten with demo-mode text;
    otherwise the app is returned unchanged.
    """
    if not os.getenv('SPACE_ID'):
        return demo

    demo.title = "🤖 Asistente de Base de Datos SQL (Demo)"
    demo.description = """
Este es un demo del asistente de base de datos SQL.
Para usar la versión completa con conexión a base de datos, clona este espacio y configura las variables de entorno.
"""
    return demo
|
|
|
|
|
if __name__ == "__main__":
    # RUN_FLASK=true (case-insensitive) serves the Flask JSON API; otherwise the
    # Gradio UI is launched on port 7860.
    run_flask = os.environ.get('RUN_FLASK', 'false').lower() == 'true'

    if run_flask:
        flask_app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)))
    else:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            debug=True,
        )
|
|