import json

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the models once and cache them so Streamlit does not reload them on every rerun
@st.cache_resource
def load_models():
    tokenizer_sql = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    model_sql = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    tokenizer_question = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    model_question = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    return tokenizer_sql, model_sql, tokenizer_question, model_question

tokenizer_sql, model_sql, tokenizer_question, model_question = load_models()
# Function to create the prompt for the SQL model
def get_prompt_sql(tables, question):
    return f"""convert question and table into SQL query. tables: {tables}. question: {question}"""
# Function to prepare input data for the SQL model
def prepare_input_sql(question: str, tables: dict):
    # Serialize each table schema as "table_name(col1,col2)" and join them with commas
    table_strings = [f"""{table_name}({','.join(tables[table_name])})""" for table_name in tables]
    tables_str = ", ".join(table_strings)
    prompt = get_prompt_sql(tables_str, question)
    input_ids = tokenizer_sql(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids
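
# For example, the app's default schema
#   {"people_name": ["id", "name"], "people_age": ["people_id", "age"]}
# is serialized by prepare_input_sql as:
#   people_name(id,name), people_age(people_id,age)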
# Inference function for the SQL model
def inference_sql(question: str, tables: dict) -> str:
    input_data = prepare_input_sql(question=question, tables=tables)
    input_data = input_data.to(model_sql.device)
    # Beam search decoding; note that top_k only takes effect when sampling is enabled
    outputs = model_sql.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    return tokenizer_sql.decode(outputs[0], skip_special_tokens=True)
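
# Illustrative usage (the exact output depends on the model):
#   inference_sql("get the name of people older than 25",
#                 {"people_name": ["id", "name"], "people_age": ["people_id", "age"]})
# might return something like:
#   SELECT people_name.name FROM people_name JOIN people_age
#   ON people_name.id = people_age.people_id WHERE people_age.age > 25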
# Function to create the prompt for the Question Generation model.
# Note: mrm8488/t5-base-finetuned-question-generation-ap appears to have been
# fine-tuned on "answer: ... context: ..." style inputs, so this free-form
# prompt may underperform the model's native input format.
def get_prompt_question(context):
    return f"generate a question from the following context: {context}"
# Function to prepare input data for the Question Generation model
def prepare_input_question(context: str):
    prompt = get_prompt_question(context)
    input_ids = tokenizer_question(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids
# Inference function for the Question Generation model
def inference_question(context: str) -> str:
    input_data = prepare_input_question(context)
    input_data = input_data.to(model_question.device)
    outputs = model_question.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    return tokenizer_question.decode(outputs[0], skip_special_tokens=True)
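
# Illustrative usage (the exact output depends on the model):
#   inference_question("Python was created by Guido van Rossum and released in 1991.")
# might return something like: "Who created Python?"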
# Streamlit UI
def main():
    st.title("Multi-Model: Text to SQL and Question Generation")

    # Model selection
    model_choice = st.selectbox("Select a model", ["Text to SQL", "Question Generation"])

    # Input question and table schemas for the SQL model
    if model_choice == "Text to SQL":
        st.subheader("Text to SQL Model")
        question = st.text_area("Enter your question:")
        tables_input = st.text_area("Enter table schemas (in JSON format):", '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
        # Parse the schema with json.loads rather than eval, which would execute arbitrary code
        try:
            tables = json.loads(tables_input)
        except json.JSONDecodeError:
            tables = {}
        if st.button("Generate SQL Query"):
            if question and tables:
                sql_query = inference_sql(question, tables)
                st.write(f"Generated SQL Query: {sql_query}")
            else:
                st.write("Please enter both a question and valid table schemas.")
    # Input context for the Question Generation model
    elif model_choice == "Question Generation":
        st.subheader("Question Generation Model")
        context = st.text_area("Enter context:")
        if st.button("Generate Question"):
            if context:
                generated_question = inference_question(context)
                st.write(f"Generated Question: {generated_question}")
            else:
                st.write("Please enter context for question generation.")
if __name__ == "__main__":
    main()
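
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py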