# Spaces:
# Sleeping
# Sleeping
import json

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the models and tokenizers.
# Streamlit re-executes this whole script on every widget interaction, so the
# loading is wrapped in st.cache_resource (available since Streamlit 1.18):
# each checkpoint is loaded once per server process and the same objects are
# reused on later reruns instead of being reloaded every time.
@st.cache_resource
def _load_seq2seq(checkpoint):
    """Load and return ``(tokenizer, model)`` for a seq2seq checkpoint.

    Args:
        checkpoint: Hugging Face Hub model id passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple built with ``AutoTokenizer`` and
        ``AutoModelForSeq2SeqLM``.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tokenizer, model


question_tokenizer, question_model = _load_seq2seq("mrm8488/t5-base-finetuned-question-generation-ap")
sql_tokenizer, sql_model = _load_seq2seq("juierror/flan-t5-text2sql-with-schema-v2")
# Function to generate a question based on a table schema
def generate_question(tables):
    """Generate a natural-language question from a table schema.

    Args:
        tables: Mapping of table name to an iterable of column-name strings,
            e.g. ``{"people_name": ["id", "name"]}``.

    Returns:
        The decoded output of the question-generation model (beam search with
        5 beams, ``max_length=50``), with special tokens removed.
    """
    # Render the schema as "table: (col1, col2), other: (col3)".
    table_str = ", ".join(f"{table}: ({', '.join(cols)})" for table, cols in tables.items())
    # NOTE(review): this free-form instruction may not match the input format the
    # question-generation checkpoint was fine-tuned on — confirm against its model card.
    prompt = f"Generate a question based on the following table schema: {table_str}"
    # Put the inputs on the model's device (as generate_sql does) so this keeps
    # working if the model is ever placed on a GPU.
    input_ids = question_tokenizer(prompt, return_tensors="pt").input_ids.to(question_model.device)
    output = question_model.generate(input_ids, num_beams=5, max_length=50)
    return question_tokenizer.decode(output[0], skip_special_tokens=True)
# Function to prepare input data for SQL generation
def prepare_sql_input(question, tables):
    """Build and tokenize the text-to-SQL prompt.

    Args:
        question: Natural-language question to translate into SQL.
        tables: Mapping of table name to an iterable of column-name strings.

    Returns:
        A PyTorch tensor of input ids for a batch of one prompt, truncated to
        at most 512 tokens.
    """
    # Render the schema as "table(col1, col2), other(col3)".
    table_str = ", ".join(f"{table}({', '.join(cols)})" for table, cols in tables.items())
    # NOTE(review): fine-tuned checkpoints are sensitive to prompt wording/casing;
    # verify this text against the prompt shown on the model card before relying on it.
    prompt = f"Convert the question and table schema into an SQL query. Tables: {table_str}. Question: {question}"
    # truncation=True makes the 512-token cap explicit; passing max_length alone
    # only truncates through an implicit fallback that logs a warning.
    return sql_tokenizer(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
# Inference function for SQL generation
def generate_sql(question, tables):
    """Translate a natural-language question into SQL for the given schema.

    Args:
        question: Natural-language question.
        tables: Mapping of table name to an iterable of column-name strings.

    Returns:
        The decoded SQL string produced by 10-beam search (``max_length=512``),
        with special tokens removed.
    """
    input_ids = prepare_sql_input(question, tables).to(sql_model.device)
    # top_k was removed: it only affects sampling (do_sample=True). Assuming the
    # checkpoint's generation config leaves sampling off, plain beam search
    # ignored it and it only triggered a generation-config warning.
    outputs = sql_model.generate(inputs=input_ids, num_beams=10, max_length=512)
    return sql_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Streamlit UI
def _parse_tables(text):
    """Parse the schema text into ``{table_name: [column, ...]}``.

    Args:
        text: Raw text from the schema input. It is expected to be a JSON
            object mapping table names to lists of column-name strings.

    Returns:
        The parsed dict, or ``None`` if ``text`` is not valid JSON or does not
        have that shape. Rejecting bad shapes here prevents the later
        ``.items()`` / ``str.join`` calls from crashing.
    """
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    # JSON object keys are always strings; only the column lists need checking.
    for cols in data.values():
        if not isinstance(cols, list) or not all(isinstance(col, str) for col in cols):
            return None
    return data


def main():
    """Render the app: schema input, generated question, and SQL output."""
    st.title("Multi-Model: Text to SQL and Question Generation")
    # Input table schema
    tables_input = st.text_area("Enter table schemas (in JSON format):",
                                '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
    # Parse with json.loads instead of eval(): this text comes straight from the
    # user, and eval() would execute arbitrary Python code contained in it.
    tables = _parse_tables(tables_input)
    if tables is None:
        if tables_input.strip():
            st.error('Table schemas must be a JSON object mapping table names to lists of '
                     'column names, e.g. {"people_name": ["id", "name"]}.')
        tables = {}
    # If tables are provided, generate a question
    generated_question = ""
    if tables:
        generated_question = generate_question(tables)
        st.write(f"Generated Question: {generated_question}")
    # Input question manually if needed (pre-filled with the generated one)
    question = st.text_area("Enter your question (optional):", generated_question)
    if st.button("Generate SQL Query"):
        if question.strip() and tables:
            sql_query = generate_sql(question, tables)
            st.write(f"Generated SQL Query: {sql_query}")
        else:
            st.write("Please enter both a question and table schemas.")
if __name__ == "__main__":
    # Render the UI when this file is executed as the main script
    # (e.g. via `streamlit run`).
    main()