import json
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load all models once and cache them so Streamlit reruns do not reload them
@st.cache_resource
def load_models():
    tokenizer_sql = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    model_sql = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    tokenizer_question = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    model_question = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    return tokenizer_sql, model_sql, tokenizer_question, model_question

tokenizer_sql, model_sql, tokenizer_question, model_question = load_models()
# Build the prompt expected by the text-to-SQL model
def get_prompt_sql(tables, question):
    return f"convert question and table into SQL query. tables: {tables}. question: {question}"

# Serialize the schema dict as "table(col1,col2)" strings and tokenize the prompt
def prepare_input_sql(question: str, tables: dict):
    table_reprs = [f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()]
    prompt = get_prompt_sql(", ".join(table_reprs), question)
    input_ids = tokenizer_sql(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids
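# Example (illustrative; the question text is made up for this sketch):
#   prepare_input_sql("how many people are older than 25?",
#                     {"people_name": ["id", "name"], "people_age": ["people_id", "age"]})
# tokenizes the prompt:
#   convert question and table into SQL query. tables: people_name(id,name),
#   people_age(people_id,age). question: how many people are older than 25?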
# Run beam-search generation with the text-to-SQL model and decode the result
def inference_sql(question: str, tables: dict) -> str:
    input_data = prepare_input_sql(question=question, tables=tables).to(model_sql.device)
    outputs = model_sql.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    return tokenizer_sql.decode(outputs[0], skip_special_tokens=True)
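# Example (hedged): a call such as
#   inference_sql("how many people are older than 25?",
#                 {"people_name": ["id", "name"], "people_age": ["people_id", "age"]})
# returns a single SQL string; the exact query depends on the checkpoint,
# e.g. something along the lines of a SELECT COUNT(*) with a JOIN on people_id.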
# Build the prompt expected by the question-generation model.
# mrm8488/t5-base-finetuned-question-generation-ap was fine-tuned on
# "answer: ...  context: ..." inputs, so a free-form instruction would not work well.
def get_prompt_question(answer, context):
    return f"answer: {answer}  context: {context}"

# Tokenize the question-generation prompt
def prepare_input_question(answer: str, context: str):
    prompt = get_prompt_question(answer, context)
    input_ids = tokenizer_question(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids

# Run beam-search generation with the question-generation model and decode the result
def inference_question(answer: str, context: str) -> str:
    input_data = prepare_input_question(answer, context).to(model_question.device)
    outputs = model_question.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    return tokenizer_question.decode(outputs[0], skip_special_tokens=True)
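# Example (hedged; the strings below are made up for this sketch):
#   inference_question("42", "The answer to life, the universe and everything is 42.")
# should produce a question whose answer is "42"; the output may carry a
# "question:" prefix depending on the checkpoint's training format.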
# Streamlit UI
def main():
    st.title("Multi-Model: Text to SQL and Question Generation")

    # Model selection
    model_choice = st.selectbox("Select a model", ["Text to SQL", "Question Generation"])

    # Input question and table schemas for the SQL model
    if model_choice == "Text to SQL":
        st.subheader("Text to SQL Model")
        question = st.text_area("Enter your question:")
        tables_input = st.text_area(
            "Enter table schemas (in JSON format):",
            '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}',
        )
        # Parse the schema as JSON; eval() would execute arbitrary user input
        try:
            tables = json.loads(tables_input)
        except json.JSONDecodeError:
            tables = {}

        if st.button("Generate SQL Query"):
            if question and tables:
                sql_query = inference_sql(question, tables)
                st.write(f"Generated SQL Query: {sql_query}")
            else:
                st.write("Please enter a question and valid JSON table schemas.")

    # Input answer and context for the Question Generation model
    elif model_choice == "Question Generation":
        st.subheader("Question Generation Model")
        context = st.text_area("Enter context:")
        answer = st.text_area("Enter the answer the question should target:")
        if st.button("Generate Question"):
            if context and answer:
                generated_question = inference_question(answer, context)
                st.write(f"Generated Question: {generated_question}")
            else:
                st.write("Please enter both a context and an answer.")

if __name__ == "__main__":
    main()
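# To run the app locally (assumes streamlit, transformers, and torch are installed):
#   streamlit run app.py
# The file name app.py is the usual Hugging Face Spaces entry point and is an
# assumption here; substitute whatever this script is actually saved as.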