import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import torch
#import pkg_resources
'''
# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
# Print the list of packages
for package, version in installed_packages.items():
    print(f"{package}=={version}")
'''
# Load the chatbot model
chatbot_model_name = "microsoft/DialoGPT-medium"
chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
# Load the TAPEX table question-answering model; any of the checkpoints below can be used
# Note: the WikiSQL-finetuned checkpoints take longer to process
#model_name = "microsoft/tapex-large-finetuned-wikisql"
#model_name = "microsoft/tapex-base-finetuned-wikisql"
model_name = "microsoft/tapex-large-finetuned-wtq"
#model_name = "microsoft/tapex-base-finetuned-wtq"
tokenizer = TapexTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
data = {
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
}
table = pd.DataFrame.from_dict(data)
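# Note: the TAPEX tokenizer flattens this DataFrame together with the question, so
# the small table above is the only context the model can answer from.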
chat_history_ids = None

def chatbot_response(user_message):
    # Generate the chatbot response using DialoGPT, keeping the running chat
    # history in the module-level variable chat_history_ids
    #inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
    #outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
    #response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
    global chat_history_ids

    # Encode the new user input, add the eos_token, and return a PyTorch tensor
    new_user_input_ids = chatbot_tokenizer.encode(
        user_message + chatbot_tokenizer.eos_token, return_tensors="pt"
    )

    # Append the new user input tokens to the chat history
    if chat_history_ids is None:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # Generate a response while limiting the total chat history to 1000 tokens
    chat_history_ids = chatbot_model.generate(
        bot_input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id
    )

    # Decode only the newly generated tokens (the bot's latest reply)
    response = chatbot_tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )
    return response
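# Illustrative usage (not part of the app): chatbot_response("Hi there!") returns
# DialoGPT's reply and extends chat_history_ids so follow-up calls keep context;
# the exact reply text depends on the model and the accumulated history.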
def sql_response(user_query):
    # Answer the question against the in-memory table with TAPEX
    #inputs = tokenizer.encode("User: " + user_query, return_tensors="pt")
    encoding = tokenizer(table=table, query=user_query, return_tensors="pt")
    outputs = model.generate(**encoding)
    # batch_decode returns a list containing the decoded answer string
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return response
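# Illustrative usage (not part of the app): sql_response("in which year did beijing
# host the olympic games?") should decode an answer such as [" 2008"] from the table;
# the model generates the answer text directly rather than executing real SQL.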
# Define the chatbot and table question-answering interfaces using Gradio
chatbot_interface = gr.Interface(
    fn=chatbot_response,
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(label="Bot:"),
    live=True,  # live mode re-runs the function as the user types
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Define the table question-answering interface using Gradio
sql_interface = gr.Interface(
    fn=sql_response,
    inputs=gr.Textbox(label="Enter your question:"),
    outputs=gr.Textbox(label="Answer:"),
    live=True,
    title="ST SQL Chatbot",
    description="Ask a question about the sample Olympics table, and the model will answer.",
)
# Launch both interfaces in a single Gradio app; calling launch() twice in a row
# would block on the first interface and never start the second one
if __name__ == "__main__":
    gr.TabbedInterface(
        [chatbot_interface, sql_interface],
        tab_names=["Chatbot", "Table QA"],
    ).launch()
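# Assumed local setup: pip install gradio transformers torch pandas, then run this
# script with Python; Gradio serves the app on a local URL (port 7860 by default).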