import json

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the models once and cache them so Streamlit does not re-download
# and re-initialize them on every rerun of the script
@st.cache_resource
def load_models():
    tokenizer_sql = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    model_sql = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema-v2")
    tokenizer_question = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    model_question = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
    return tokenizer_sql, model_sql, tokenizer_question, model_question

tokenizer_sql, model_sql, tokenizer_question, model_question = load_models()
# Build the prompt for the text-to-SQL model
def get_prompt_sql(tables, question):
    return f"convert question and table into SQL query. tables: {tables}. question: {question}"

# Serialize the schema dict as "table(col1,col2)" pairs and tokenize the prompt
def prepare_input_sql(question: str, tables: dict):
    tables = [f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()]
    tables = ", ".join(tables)
    prompt = get_prompt_sql(tables, question)
    input_ids = tokenizer_sql(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids
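# For the default schema used in the UI below, prepare_input_sql tokenizes a
# prompt of the form:
#   convert question and table into SQL query. tables: people_name(id,name),
#   people_age(people_id,age). question: <your question>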
# Run beam-search generation for the text-to-SQL model
def inference_sql(question: str, tables: dict) -> str:
    input_data = prepare_input_sql(question=question, tables=tables)
    input_data = input_data.to(model_sql.device)
    outputs = model_sql.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    return tokenizer_sql.decode(outputs[0], skip_special_tokens=True)
# Build the prompt for the question generation model
def get_prompt_question(context):
    return f"generate a question from the following context: {context}"
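# Note: the mrm8488/t5-base-finetuned-question-generation-ap model card prompts
# the model with "answer: <answer>  context: <context>"; the free-form prompt
# above is a simplification and may produce weaker questions than the
# documented format.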
# Tokenize the prompt for the question generation model
def prepare_input_question(context: str):
    prompt = get_prompt_question(context)
    input_ids = tokenizer_question(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids

# Run beam-search generation for the question generation model
def inference_question(context: str) -> str:
    input_data = prepare_input_question(context)
    input_data = input_data.to(model_question.device)
    outputs = model_question.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    return tokenizer_question.decode(outputs[0], skip_special_tokens=True)
# Streamlit UI
def main():
    st.title("Multi-Model: Text to SQL and Question Generation")

    # Model selection
    model_choice = st.selectbox("Select a model", ["Text to SQL", "Question Generation"])

    # Input question and table schema for the SQL model
    if model_choice == "Text to SQL":
        st.subheader("Text to SQL Model")
        question = st.text_area("Enter your question:")
        tables_input = st.text_area("Enter table schemas (in JSON format):", '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
        try:
            tables = json.loads(tables_input)  # Parse the JSON schema; json.loads avoids the code-execution risk of eval
        except json.JSONDecodeError:
            tables = {}
        if st.button("Generate SQL Query"):
            if question and tables:
                sql_query = inference_sql(question, tables)
                st.write(f"Generated SQL Query: {sql_query}")
            else:
                st.write("Please enter both a question and table schemas.")

    # Input context for the question generation model
    elif model_choice == "Question Generation":
        st.subheader("Question Generation Model")
        context = st.text_area("Enter context:")
        if st.button("Generate Question"):
            if context:
                generated_question = inference_question(context)
                st.write(f"Generated Question: {generated_question}")
            else:
                st.write("Please enter context for question generation.")

if __name__ == "__main__":
    main()
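The inference helpers can also be exercised outside the Streamlit UI. A minimal sketch, assuming the two checkpoints above have been downloaded and reusing the default schema from the app (the exact generated strings depend on the models):

# Hypothetical direct invocation, bypassing the Streamlit widgets
tables = {"people_name": ["id", "name"], "people_age": ["people_id", "age"]}
print(inference_sql("how many people are older than 30?", tables))
print(inference_question("Streamlit is an open-source Python framework for building data apps."))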