from flask import Flask, request, jsonify, render_template_string
from vllm import LLM, SamplingParams
from langchain_community.cache import GPTCache
import torch

app = Flask(__name__)

# Verificar si hay una GPU disponible, si no usar la CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inicializar los modelos con el dispositivo adecuado (GPU o CPU)
try:
    modelos = {
        "facebook/opt-125m": LLM(model="facebook/opt-125m", device=device),
        "llama-3.2-1B": LLM(model="Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device),
        "gpt2": LLM(model="gpt2", device=device)
    }
except KeyError as e:
    print(f"Error al inicializar el modelo con {device}: {e}")
    modelos = {}

# Verificar si los modelos fueron correctamente inicializados
if not modelos:
    print("Error: No se pudo inicializar ningún modelo.")
    exit(1)

# Configuración de caché para los modelos
caches = {
    nombre: GPTCache(modelo, max_size=1000)
    for nombre, modelo in modelos.items()
}

# Parámetros de muestreo para la generación de texto
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Código HTML para la documentación de la API
html_code_docs = """
<!DOCTYPE html>
<html>
<head>
    <title>Documentación de la API</title>
</head>
<body>
    <h1>API de Generación de Texto</h1>
    <h2>Endpoints</h2>
    <ul>
        <li>
            <h3>Generar texto</h3>
            <p>Método: POST</p>
            <p>Ruta: /generate</p>
            <p>Parámetros:</p>
            <ul>
                <li>prompts: Lista de prompts para generar texto</li>
                <li>modelo: Nombre del modelo a utilizar</li>
            </ul>
            <p>Ejemplo:</p>
            <pre>curl -X POST -H "Content-Type: application/json" -d '{"prompts": ["Hola, cómo estás?"], "modelo": "facebook/opt-125m"}' http://localhost:5000/generate</pre>
        </li>
        <li>
            <h3>Obtener lista de modelos</h3>
            <p>Método: GET</p>
            <p>Ruta: /modelos</p>
            <p>Ejemplo:</p>
            <pre>curl -X GET http://localhost:5000/modelos</pre>
        </li>
        <li>
            <h3>Chatbot</h3>
            <p>Método: POST</p>
            <p>Ruta: /chatbot</p>
            <p>Parámetros:</p>
            <ul>
                <li>mensaje: Mensaje para el chatbot</li>
                <li>modelo: Nombre del modelo a utilizar</li>
            </ul>
            <p>Ejemplo:</p>
            <pre>curl -X POST -H "Content-Type: application/json" -d '{"mensaje": "Hola, cómo estás?", "modelo": "facebook/opt-125m"}' http://localhost:5000/chatbot</pre>
        </li>
    </ul>
</body>
</html>
"""

# Código HTML para la interfaz del chatbot
html_code_chatbot = """
<!DOCTYPE html>
<html>
<head>
    <title>Chatbot</title>
</head>
<body>
    <h1>Chatbot</h1>
    <form id="chat-form">
        <input type="text" id="mensaje" placeholder="Escribe un mensaje">
        <button type="submit">Enviar</button>
    </form>
    <div id="respuestas"></div>

    <script>
        const form = document.getElementById('chat-form');
        const mensajeInput = document.getElementById('mensaje');
        const respuestasDiv = document.getElementById('respuestas');

        form.addEventListener('submit', (e) => {
            e.preventDefault();
            const mensaje = mensajeInput.value;
            fetch('/chatbot', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ mensaje })
            })
            .then((res) => res.json())
            .then((data) => {
                const respuesta = data.respuesta;
                const respuestaHTML = `<p>Tú: ${mensaje}</p><p>Chatbot: ${respuesta}</p>`;
                respuestasDiv.innerHTML += respuestaHTML;
                mensajeInput.value = '';
            });
        });
    </script>
</body>
</html>
"""

@app.route('/generate', methods=['POST'])
def generate():
    data = request.get_json()
    prompts = data.get('prompts', [])
    modelo_seleccionado = data.get('modelo', "facebook/opt-125m")

    if modelo_seleccionado not in modelos:
        return jsonify({"error": "Modelo no encontrado"}), 404

    outputs = caches[modelo_seleccionado].generate(prompts, sampling_params)

    results = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        results.append({
            'prompt': prompt,
            'generated_text': generated_text
        })

    return jsonify(results)

@app.route('/modelos', methods=['GET'])
def get_modelos():
    return jsonify({"modelos": list(modelos.keys())})

@app.route('/docs', methods=['GET'])
def docs():
    return render_template_string(html_code_docs)

@app.route('/chatbot', methods=['POST'])
def chatbot():
    data = request.get_json()
    mensaje = data.get('mensaje', '')
    modelo_seleccionado = data.get('modelo', "facebook/opt-125m")

    if modelo_seleccionado not in modelos:
        return jsonify({"error": "Modelo no encontrado"}), 404

    outputs = caches[modelo_seleccionado].generate([mensaje], sampling_params)

    respuesta = outputs[0].outputs[0].text

    return jsonify({"respuesta": respuesta})

@app.route('/chat', methods=['GET'])
def chat():
    return render_template_string(html_code_chatbot)

if __name__ == '__main__':
    # Asegurar que el servidor solo arranca si los modelos fueron inicializados correctamente
    if modelos:
        app.run(host='0.0.0.0', port=7860)
    else:
        print("Error: No se pudieron cargar los modelos. El servidor no se iniciará.")