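"""Streamlit chat UI for a Hugging Face text2text model.

Loads a T5-style model from the Hub by name, wraps it in a
text2text-generation pipeline, and keeps the chat history in
st.session_state so it survives Streamlit's script reruns.
"""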
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch

# Streamlit app setup
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# Sidebar: Model controls
st.sidebar.title("Model Controls")
model_name = st.sidebar.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# Main UI
st.title("Chat Conversation UI")

# Session state: persist the model, tokenizer, pipeline, and chat history across reruns
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []

# Load Model
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            # transformers convention: device=0 is the first GPU, device=-1 is CPU
            device = 0 if torch.cuda.is_available() else -1
            st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir="./model_cache")
            st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_name, cache_dir="./model_cache")
            st.session_state["qa_pipeline"] = pipeline(
                "text2text-generation",
                model=st.session_state["model"],
                tokenizer=st.session_state["tokenizer"],
                device=device,
            )
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            st.error(f"Error loading model: {e}")

# Clear Model
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")

# Layout for chat
chat_container = st.container()
input_container = st.container()

# Chat Conversation Display
with chat_container:
    st.subheader("Conversation")
    for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
        if speaker == "You":
            # disabled=True keeps the transcript read-only
            st.text_area(f"You ({idx}):", message, key=f"you_{idx}", disabled=True)
        else:
            st.text_area(f"Model ({idx}):", message, key=f"model_{idx}", disabled=True)
# Input Area
with input_container:
    if st.session_state["qa_pipeline"]:
        user_input = st.text_input("Enter your query:", key="chat_input", label_visibility="visible")
        if st.button("Send", key="send_button"):
            if user_input:
                with st.spinner("Generating response..."):
                    try:
                        response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=300)
                        generated_text = response[0]["generated_text"]
                        st.session_state["conversation"].append(("You", user_input))
                        st.session_state["conversation"].append(("Model", generated_text))
                    except Exception as e:
                        st.error(f"Error generating response: {e}")
                    else:
                        # Rerun so the new exchange shows up in the conversation pane
                        st.rerun()

# Clear Conversation
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")