import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import gradio as gr

import numpy as np
import time
import os

#import pyodbc

#import pkg_resources

'''
# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}

# Print the list of packages
for package, version in installed_packages.items():
    print(f"{package}=={version}")
'''

'''
# Replace the connection parameters with your SQL Server information
server = 'your_server'
database = 'your_database'
username = 'your_username'
password = 'your_password'
driver = 'SQL Server'  # This depends on the ODBC driver installed on your system

# Create the connection string
connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'

# Connect to the SQL Server
conn = pyodbc.connect(connection_string)

#============================================================================
# Replace "your_query" with your SQL query to fetch data from the database
query = 'SELECT * FROM your_table_name'

# Use pandas to read data from the SQL Server and store it in a DataFrame
df = pd.read_sql_query(query, conn)

# Close the SQL connection
conn.close()
'''

# Small demo table (columns "year" and "city") that `sqlquery` answers
# questions about via the TAPEX tokenizer/model.
data = dict(
    year=[1896, 1900, 1904, 2004, 2008, 2012],
    city=["athens", "paris", "st. louis", "athens", "beijing", "london"],
)
table = pd.DataFrame(data)


# Conversational model used by `chat`: a causal LM plus its tokenizer.
chatbot_model_name = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model_name),
    AutoModelForCausalLM.from_pretrained(chatbot_model_name),
)

# Table question-answering model used by `sqlquery`: a BART seq2seq model
# paired with the TAPEX tokenizer (which flattens a DataFrame + question).
sql_model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer, sql_model = (
    TapexTokenizer.from_pretrained(sql_model_name),
    BartForConditionalGeneration.from_pretrained(sql_model_name),
)

# (question, answer) turns appended by `sqlquery`. This is module-level
# state, so it is shared by every caller in this process.
conversation_history = []

def chat(input, history=None):
    """Generate one DialoGPT reply and return the updated conversation.

    Args:
        input: The new user message.
        history: Token ids of the whole conversation so far, as returned by
            the previous call (a nested list from ``Tensor.tolist()``).
            ``None`` or an empty list starts a new conversation. Gradio's
            ``"state"`` input may pass ``None`` on the first turn.

    Returns:
        A ``(response, history)`` tuple. ``response`` is a list of
        ``(user_message, bot_message)`` tuples for the ``"chatbot"`` output.
        ``history`` is the updated token-id list to keep in state.
    """
    # Avoid a shared mutable default, and treat a missing state as
    # "no previous turns".
    if history is None:
        history = []

    # Tokenize the new user turn, terminated by the EOS token that separates
    # turns in the conversation.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")

    # Append the new turn to the prior conversation. An empty history becomes
    # a 1-D empty tensor, which torch.cat accepts alongside the 2-D input.
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)

    # Generate a reply. The returned ids contain the full conversation.
    # NOTE(review): max_length bounds prompt + reply together and the history
    # is never trimmed, so very long conversations will stop getting replies.
    history = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).tolist()

    # Decode and split on the same EOS token used when encoding, then pair up
    # the alternating user/bot turns.
    turns = tokenizer.decode(history[0]).split(tokenizer.eos_token)
    response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]

    return response, history


def sqlquery(input):
    """Answer a question about ``table`` and return the whole transcript.

    The question is run through the TAPEX tokenizer/model against the
    module-level ``table``. The ``(question, answer)`` pair is appended to the
    module-level ``conversation_history``, and every turn so far is rendered
    as text.

    Args:
        input: Natural-language question about ``table``.

    Returns:
        The conversation as ``"User: ...\\nBot: ..."`` entries joined by
        newlines, oldest first.
    """
    # NOTE(review): conversation_history is module-level, so it is shared by
    # every caller in this process (all Gradio sessions) and grows unbounded.
    # NOTE(review): confirm that appending eos_token to the query is intended;
    # the tokenizer is also asked to build the model input itself.
    sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
    sql_outputs = sql_model.generate(**sql_encoding)

    # batch_decode returns a list of strings (one per generated sequence).
    # Join it so the transcript shows the answer text rather than a list repr.
    answers = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
    sql_response = " ".join(answer.strip() for answer in answers)

    # Only mutated via append, so no `global` declaration is needed.
    conversation_history.append((input, sql_response))

    return "\n".join(
        f"User: {user_msg}\nBot: {resp_msg}" for user_msg, resp_msg in conversation_history
    )


# Chatbot tab: `chat` receives the textbox value plus the token-id history
# held in Gradio state, and returns (chat pairs, updated history).
chat_interface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
    theme="default",
    # Hide the page footer.
    css=".footer {display:none !important}",
)


# Table-QA tab: each submission goes to `sqlquery`, whose plain-text
# transcript is shown in the output textbox.
sql_interface = gr.Interface(
    fn=sqlquery,
    theme="default",
    # `prompt` is not a `gr.Textbox` parameter (it was ignored or rejected);
    # `label` is what actually shows "You:" on the input box.
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)


# Alternative single-page interface that renders `sqlquery`'s output as HTML.
# It is constructed at import time but is not launched by the `__main__` guard.
# NOTE(review): `sqlquery` returns plain text joined with "\n", so newlines
# collapse in the HTML output and none of the CSS classes below are emitted.
# The transcript also embeds raw user input, which is shared across sessions
# via `conversation_history`. Escape it (e.g. `html.escape`) before enabling
# this interface.
iface = gr.Interface(sqlquery, "text", "html", css="""
    .chatbox {display:flex;flex-direction:column}
    .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
    .user_msg {background-color:cornflowerblue;color:white;align-self:start}
    .resp_msg {background-color:lightgray;align-self:self-end}
""", allow_flagging="never")
# `allow_flagging` takes a string ("never") rather than the deprecated boolean,
# and the Gradio 2.x-only `allow_screenshot` option has been dropped.

# Two-tab app: the DialoGPT chatbot first, then the table Q&A.
combine_interface = gr.TabbedInterface(
    [chat_interface, sql_interface],
    tab_names=["Chatbot", "SQL Chat"],
)

if __name__ == "__main__":
    # Serve the tabbed app. To serve the HTML variant instead, swap in:
    # iface.launch(debug=True)
    combine_interface.launch()