GlastonR committed on
Commit
616d6ae
·
verified ·
1 Parent(s): 33e2541

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -55
app.py CHANGED
@@ -1,72 +1,67 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
2
  import streamlit as st
 
3
 
4
- @st.cache_resource
5
- def load_models():
6
- question_model_name = "mrm8488/t5-base-finetuned-question-generation-ap"
7
- recipe_model_name = "flax-community/t5-recipe-generation"
8
- instruct_model_name = "norallm/normistral-7b-warm-instruct"
9
 
10
- # Load T5-based models for question generation and recipe generation
11
- question_model = AutoModelForSeq2SeqLM.from_pretrained(question_model_name)
12
- question_tokenizer = AutoTokenizer.from_pretrained(question_model_name)
13
 
14
- recipe_model = AutoModelForSeq2SeqLM.from_pretrained(recipe_model_name)
15
- recipe_tokenizer = AutoTokenizer.from_pretrained(recipe_model_name)
 
 
 
16
 
17
- # Load the instruction model as a causal language model
18
- instruct_model = AutoModelForCausalLM.from_pretrained(instruct_model_name)
19
- instruct_tokenizer = AutoTokenizer.from_pretrained(instruct_model_name)
 
 
20
 
21
- return (question_model, question_tokenizer), (recipe_model, recipe_tokenizer), (instruct_model, instruct_tokenizer)
 
 
 
22
 
23
- # Function to generate questions using the question model
24
- def generate_question(text, model, tokenizer):
25
- input_text = f"generate question: {text}"
26
- input_ids = tokenizer.encode(input_text, return_tensors="pt")
27
- outputs = model.generate(input_ids)
28
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
29
 
30
- # Function to generate recipes using the recipe model
31
- def generate_recipe(ingredients, model, tokenizer):
32
- input_text = f"generate recipe: {ingredients}"
33
- input_ids = tokenizer.encode(input_text, return_tensors="pt")
34
- outputs = model.generate(input_ids)
35
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
36
 
37
- # Function to generate instructions using the instruction model
38
- def generate_instruction(prompt, model, tokenizer):
39
- input_ids = tokenizer.encode(prompt, return_tensors="pt")
40
- outputs = model.generate(input_ids)
41
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
42
-
43
- # Streamlit Interface
44
  def main():
45
- st.title("Multi-Model Application")
46
-
47
- # Load all models
48
- (question_model, question_tokenizer), (recipe_model, recipe_tokenizer), (instruct_model, instruct_tokenizer) = load_models()
49
 
50
- # Tabs for different functionalities
51
- tab = st.selectbox("Choose task", ["Question Generation", "Recipe Generation", "Instruction Following"])
 
 
 
 
 
52
 
53
- if tab == "Question Generation":
54
- passage = st.text_area("Enter a passage for question generation")
55
- if st.button("Generate Question"):
56
- question = generate_question(passage, question_model, question_tokenizer)
57
- st.write("Generated Question:", question)
58
 
59
- elif tab == "Recipe Generation":
60
- ingredients = st.text_area("Enter ingredients for recipe generation")
61
- if st.button("Generate Recipe"):
62
- recipe = generate_recipe(ingredients, recipe_model, recipe_tokenizer)
63
- st.write("Generated Recipe:", recipe)
64
 
65
- elif tab == "Instruction Following":
66
- instruction_prompt = st.text_area("Enter an instruction prompt")
67
- if st.button("Generate Instruction"):
68
- instruction = generate_instruction(instruction_prompt, instruct_model, instruct_tokenizer)
69
- st.write("Generated Instruction:", instruction)
 
70
 
71
  if __name__ == "__main__":
72
  main()
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ # Load the models and tokenizers
5
+ question_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
6
+ question_model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
 
 
7
 
8
+ sql_tokenizer = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
9
+ sql_model = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
 
10
 
11
+ # Function to generate a question based on a table schema
12
def generate_question(tables):
    """Generate a natural-language question from a table-schema mapping.

    Args:
        tables: mapping of table name -> list of column names.

    Returns:
        The decoded question string produced by the question-generation model.
    """
    # Render the schema as "table: (col1, col2)" fragments joined by commas.
    schema_parts = []
    for table_name, columns in tables.items():
        schema_parts.append(f"{table_name}: ({', '.join(columns)})")
    table_str = ", ".join(schema_parts)
    prompt = f"Generate a question based on the following table schema: {table_str}"

    # Tokenize the prompt and run beam-search generation.
    encoded = question_tokenizer(prompt, return_tensors="pt").input_ids
    generated = question_model.generate(encoded, num_beams=5, max_length=50)
    return question_tokenizer.decode(generated[0], skip_special_tokens=True)
22
 
23
+ # Function to prepare input data for SQL generation
24
def prepare_sql_input(question, tables):
    """Tokenize a text-to-SQL prompt built from a question and table schemas.

    Args:
        question: the natural-language question to convert.
        tables: mapping of table name -> list of column names.

    Returns:
        Input-id tensor ready to feed to the SQL generation model.
    """
    # Schema fragments look like "table(col1, col2)" — note: no colon here,
    # unlike the question-generation prompt format.
    fragments = (f"{name}({', '.join(columns)})" for name, columns in tables.items())
    table_str = ", ".join(fragments)
    prompt = f"Convert the question and table schema into an SQL query. Tables: {table_str}. Question: {question}"
    return sql_tokenizer(prompt, max_length=512, return_tensors="pt").input_ids
 
 
 
 
30
 
31
+ # Inference function for SQL generation
32
def generate_sql(question, tables):
    """Generate an SQL query string for *question* against the given schemas.

    Args:
        question: natural-language question.
        tables: mapping of table name -> list of column names.

    Returns:
        The decoded SQL query produced by the text-to-SQL model.
    """
    encoded = prepare_sql_input(question, tables)
    # Keep the tensor on the same device as the model before generating.
    encoded = encoded.to(sql_model.device)
    generated = sql_model.generate(inputs=encoded, num_beams=10, top_k=10, max_length=512)
    return sql_tokenizer.decode(generated[0], skip_special_tokens=True)
38
 
39
+ # Streamlit UI
 
 
 
 
 
 
40
def main():
    """Streamlit UI: collect a table schema, auto-generate a question, emit SQL.

    Flow: the user supplies table schemas as JSON; a question is generated from
    the schema (and can be overridden manually); on button press the question +
    schema are converted to an SQL query and displayed.
    """
    # Local import so this edit is self-contained; json is stdlib.
    import json

    st.title("Multi-Model: Text to SQL and Question Generation")

    # Input table schema (the prompt explicitly asks for JSON, and the default
    # value is valid JSON, so json.loads is the correct parser).
    tables_input = st.text_area("Enter table schemas (in JSON format):",
                                '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
    try:
        # SECURITY: parse with json.loads, never eval() — eval executes
        # arbitrary user-supplied code.
        tables = json.loads(tables_input)
    except (json.JSONDecodeError, TypeError):
        # Malformed input: fall back to an empty schema rather than crashing.
        tables = {}

    # If tables are provided, generate a question from the schema.
    if tables:
        generated_question = generate_question(tables)
        st.write(f"Generated Question: {generated_question}")

    # Let the user override the auto-generated question (or type one if no
    # valid schema was given).
    question = st.text_area("Enter your question (optional):", generated_question if tables else "")

    if st.button("Generate SQL Query"):
        if question and tables:
            sql_query = generate_sql(question, tables)
            st.write(f"Generated SQL Query: {sql_query}")
        else:
            st.write("Please enter both a question and table schemas.")


if __name__ == "__main__":
    main()