import os
import gradio as gr
import sqlparse
import requests
from time import sleep
import re
import platform
# Additional Firebase imports
import firebase_admin
from firebase_admin import credentials, firestore
import json
import base64

print(f"Running on {platform.system()}")

if platform.system() in ("Windows", "Darwin"):
    from dotenv import load_dotenv
    load_dotenv()
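# Elsewhere (e.g. on a Hugging Face Space) FIREBASE and KOBOLDCPP_API_URL are
# assumed to be provided by the host environment.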

quantized_model = "richardr1126/spider-skeleton-wizard-coder-ggml"
merged_model = "richardr1126/spider-skeleton-wizard-coder-merged"
initial_model = "WizardLM/WizardCoder-15B-V1.0"
lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
dataset = "richardr1126/spider-skeleton-context-instruct"

# Firebase code
# Initialize Firebase from the FIREBASE env var, which holds a
# base64-encoded service-account JSON key
base64_string = os.getenv('FIREBASE')
if not base64_string:
    raise RuntimeError("The FIREBASE environment variable is not set")
firebase_auth = json.loads(base64.b64decode(base64_string).decode('utf-8'))

# Load credentials and initialize Firestore
cred = credentials.Certificate(firebase_auth)
firebase_admin.initialize_app(cred)
db = firestore.client()

def log_message_to_firestore(input_message, db_info, temperature, response_text):
    doc_ref = db.collection('logs').document()
    log_data = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'temperature': temperature,
        'db_info': db_info,
        'input': input_message,
        'output': response_text,
    }
    doc_ref.set(log_data)

rated_outputs = set()  # set to store already rated outputs

def log_rating_to_firestore(input_message, db_info, temperature, response_text, rating):
    global rated_outputs
    output_id = f"{input_message} {db_info} {response_text} {temperature}"

    if output_id in rated_outputs:
        gr.Warning("You've already rated this output!")
        return
    if not input_message or not response_text or not rating:
        gr.Info("You haven't asked a question yet!")
        return
    
    rated_outputs.add(output_id)

    doc_ref = db.collection('ratings').document()
    log_data = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'temperature': temperature,
        'db_info': db_info,
        'input': input_message,
        'output': response_text,
        'rating': rating,
    }
    doc_ref.set(log_data)
    gr.Info("Thanks for your feedback!")
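
# Note: each log/rating is stored as a new auto-ID document;
# db.collection(...).document() with no arguments generates a random document id.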
# End Firebase code

def format(text):
    # The model emits "skeleton | final query"; keep the part after the last "|"
    # (if there is no "|", this is just the whole text)
    final_query = text.split("|")[-1].strip()

    try:
        # Reindent the SQL and uppercase keywords with sqlparse
        formatted_query = sqlparse.format(final_query, reindent=True, keyword_case='upper')
    except Exception:
        # If formatting fails, fall back to the unformatted query
        formatted_query = final_query

    return formatted_query
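
# Illustrative example: format("select _ from _ | select name from singer")
# returns "SELECT name\nFROM singer" (sqlparse reindents and uppercases keywords).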

def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, stop_sequence="###", log=False):
    # Format the user's input message
    messages = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to sql: {input_message} {db_info}\n\n### Response:\n\n"

    url = os.getenv("KOBOLDCPP_API_URL")
    # Always stop on "###", plus any user-supplied comma-separated sequences
    stop = ["###"] + [s for s in stop_sequence.split(",") if s]
    payload = {
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "top_a": 0,
        "n": 1,
        "max_context_length": 2048,
        "max_length": 512,
        "rep_pen": repetition_penalty,
        "sampler_order": [6,0,1,3,4,2,5],
        "stop_sequence": stop,
    }
    headers = {
        "Content-Type": "application/json",
        "ngrok-skip-browser-warning": "1"  # added this line
    }
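    # KOBOLDCPP_API_URL is assumed to point at KoboldCpp's generate endpoint
    # (typically http://<host>:5001/api/v1/generate), which responds with
    # {"results": [{"text": "..."}]}.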

    for _ in range(3): # Try 3 times
        try:
            response = requests.post(url, json=payload, headers=headers)
            response_text = response.json()["results"][0]["text"]
            response_text = response_text.replace("\n", "").replace("\t", " ")
            if response_text and response_text[-1] == ".":
                response_text = response_text[:-1]

            output = format(response_text) if format_sql else response_text

            if log:
                # Log the request to Firestore
                log_message_to_firestore(input_message, db_info, temperature, output if format_sql else response_text)

            return output

        except Exception as e:
            print(f'Error occurred: {str(e)}')
            print('Waiting for 10 seconds before retrying...')
            gr.Warning("Error occurred, retrying; the server may be down...")
            sleep(10)

    # All retries failed; return an empty string so the UI shows a blank output
    gr.Warning("The model server did not respond after 3 attempts.")
    return ""

# Gradio UI Code
with gr.Blocks(theme='gradio/soft') as demo:
    # Elements stack vertically by default; just define them in the order you want them to appear
    header = gr.HTML("""
        <h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
        <h3 style="text-align: center">🕷️☠️🧙‍♂️ Generate SQL queries from Natural Language 🕷️☠️🧙‍♂️</h3>
        <div style="max-width: 450px; margin: auto; text-align: center">
            <p style="font-size: 12px; text-align: center">⚠️ Should take 30-60s to generate. Please rate the response, it helps a lot. If you get a blank output, the model server is currently down, please try again another time.</p>
        </div>
    """)

    output_box = gr.Code(label="Generated SQL", lines=2, interactive=False)

    with gr.Row():
        rate_up = gr.Button("👍", variant="secondary")
        rate_down = gr.Button("👎", variant="secondary")

    input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
    db_info = gr.Textbox(lines=4, placeholder='Make sure to place your table information between | characters for better results. Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')
    format_sql = gr.Checkbox(label="Format SQL + Remove Skeleton", value=True, interactive=True)
    
    with gr.Row():
        run_button = gr.Button("Generate SQL", variant="primary")
        clear_button = gr.ClearButton(variant="secondary")

    with gr.Accordion("Options", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
        top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
        stop_sequence = gr.Textbox(lines=1, value="Explanation,Note", label='Extra Stop Sequences (comma-separated)')
    
    info = gr.HTML(f"""
        <p>🌐 Leveraging the <a href='https://huggingface.co/{quantized_model}'><strong>4-bit GGML version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
        <p>🔗 How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{lora_model}'><strong>{lora_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
        <p>📉 Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{lora_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
        <p>📊 All inputs/outputs are logged to Firebase to see how the model is doing. You can also leave a rating for each generated SQL the model produces, which gets sent to the database as well.</a></p>
    """)

    examples = gr.Examples([
        ["What is the average, minimum, and maximum age of all singers from France?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
        ["How many students have dogs?", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid | pets.pettype = 'Dog' |"],
    ], inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, stop_sequence], fn=generate, cache_examples=platform.system() not in ("Windows", "Darwin"), outputs=output_box)
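    # Caching examples runs generate() at startup, so it is skipped on
    # Windows/macOS, where the model server is assumed to be unreachable locally.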

    with gr.Accordion("More Examples", open=False):
        more_examples = gr.Examples([
            ["What is the average weight of pets of all students?", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
            ["How many male singers performed in concerts in the year 2023?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
            ["For students who have pets, how many pets does each student have? List their ids instead of names.", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
            ["Show location and name for all stadiums with a capacity between 5000 and 10000.", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
            ["What are the number of concerts that occurred in the stadium with the largest capacity ?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
            ["Which student has the oldest pet?", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
            ["List the names of all singers who performed in a concert with the theme 'Rock'", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
            ["List all students who don't have pets.", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
        ], inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, stop_sequence], fn=generate, cache_examples=False, outputs=output_box)


    readme_content = requests.get(f"https://huggingface.co/{merged_model}/raw/main/README.md").text
    readme_content = re.sub(r'\A---.*?---', '', readme_content, flags=re.DOTALL)  # Remove the YAML front matter

    with gr.Accordion("📖 Model Readme", open=True):
        readme = gr.Markdown(
            readme_content,
        )
    
    with gr.Accordion("Disabled Options:", open=False):
        log = gr.Checkbox(label="Log to Firebase", value=True, interactive=False)
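        # Logging is always on: the checkbox is fixed at True and non-interactive.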
    
    # When the button is clicked, call generate(); inputs are taken from the UI elements and the output is sent to the output box
    run_button.click(fn=generate, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, stop_sequence, log], outputs=output_box, api_name="txt2sql")
    clear_button.add([input_text, db_info, output_box])

    # Firebase code - for rating the generated SQL (remove if you don't want to use Firebase).
    # The clicked button is passed as an input so that its label ("👍"/"👎") becomes the rating value.
    rate_up.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_up])
    rate_down.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_down])

demo.queue(concurrency_count=1, max_size=20).launch(debug=True)