import streamlit as st
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the models once and cache them across Streamlit reruns
# (Streamlit re-executes the script on every interaction, so uncached
# module-level loads would re-instantiate both models each time)
@st.cache_resource
def load_models():
    tokenizer_sql = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    model_sql = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    tokenizer_question = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    model_question = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    return tokenizer_sql, model_sql, tokenizer_question, model_question

tokenizer_sql, model_sql, tokenizer_question, model_question = load_models()

# Function to create the prompt for SQL model
def get_prompt_sql(tables, question):
    return f"""convert question and table into SQL query. tables: {tables}. question: {question}"""

# Function to prepare tokenized input for the SQL model
def prepare_input_sql(question: str, tables: dict):
    # Flatten the schema dict into "table(col1,col2)" fragments,
    # e.g. {"people_name": ["id", "name"]} -> "people_name(id,name)"
    schema_parts = [f"{name}({','.join(columns)})" for name, columns in tables.items()]
    schema = ", ".join(schema_parts)
    prompt = get_prompt_sql(schema, question)
    # truncation=True actually enforces the 512-token limit; max_length alone only warns
    input_ids = tokenizer_sql(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids
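
# Illustrative full prompt for the default schema used in the UI below
# (the question text here is hypothetical):
#   convert question and table into SQL query. tables: people_name(id,name),
#   people_age(people_id,age). question: how many people are older than 25?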

# Inference function for the SQL model
def inference_sql(question: str, tables: dict) -> str:
    input_data = prepare_input_sql(question=question, tables=tables)
    input_data = input_data.to(model_sql.device)
    # top_k only affects sampling and is ignored by beam search, so it is dropped here
    outputs = model_sql.generate(inputs=input_data, num_beams=10, max_length=512)
    return tokenizer_sql.decode(outputs[0], skip_special_tokens=True)
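
# Illustrative call (actual output depends on the checkpoint):
#   inference_sql("how many people are older than 25?",
#                 {"people_name": ["id", "name"], "people_age": ["people_id", "age"]})
#   might return something like: SELECT COUNT(*) FROM people_age WHERE age > 25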

# Function to create the prompt for the Question Generation model. This
# checkpoint was fine-tuned on inputs of the form "answer: ... context: ...",
# so a free-form instruction would fall outside its training distribution.
def get_prompt_question(answer, context):
    return f"answer: {answer} context: {context}"

# Function to prepare tokenized input for the Question Generation model
def prepare_input_question(answer: str, context: str):
    prompt = get_prompt_question(answer, context)
    input_ids = tokenizer_question(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids

# Inference function for the Question Generation model
def inference_question(answer: str, context: str) -> str:
    input_data = prepare_input_question(answer, context)
    input_data = input_data.to(model_question.device)
    outputs = model_question.generate(inputs=input_data, num_beams=10, max_length=512)
    return tokenizer_question.decode(outputs[0], skip_special_tokens=True)
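
# Illustrative call (example adapted from the model card; output may vary):
#   inference_question("Manuel",
#                      "Manuel has created RuPERTa-base with the support of HF-Transformers and Google")
#   might return something like: question: Who created RuPERTa-base?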

# Streamlit UI
def main():
    st.title("Multi-Model: Text to SQL and Question Generation")

    # Model selection
    model_choice = st.selectbox("Select a model", ["Text to SQL", "Question Generation"])

    # Input question and table schema for SQL model
    if model_choice == "Text to SQL":
        st.subheader("Text to SQL Model")
        question = st.text_area("Enter your question:")
        tables_input = st.text_area("Enter table schemas (in JSON format):", '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
        try:
            tables = json.loads(tables_input)  # Parse the JSON schema; eval() would execute arbitrary input as code
        except json.JSONDecodeError:
            tables = {}

        if st.button("Generate SQL Query"):
            if question and tables:
                sql_query = inference_sql(question, tables)
                st.write(f"Generated SQL Query: {sql_query}")
            else:
                st.write("Please enter both a question and table schemas.")

    # Input answer and context for Question Generation model
    elif model_choice == "Question Generation":
        st.subheader("Question Generation Model")
        context = st.text_area("Enter context:")
        answer = st.text_input("Enter the answer the question should target:")

        if st.button("Generate Question"):
            if context and answer:
                generated_question = inference_question(answer, context)
                st.write(f"Generated Question: {generated_question}")
            else:
                st.write("Please enter both a context and an answer.")

if __name__ == "__main__":
    main()
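
# To launch the app (assuming this file is saved as app.py; the filename is
# not given in the source): streamlit run app.py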