import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch
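# Assumed dependencies for this Space (not declared in this file): streamlit,
# transformers, torch, and sentencepiece, which the slow T5Tokenizer relies on.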
# Streamlit app setup
st.set_page_config(page_title="Hugging Face Chat", layout="wide")
# Sidebar: Model controls
st.sidebar.title("Model Controls")
model_name = st.sidebar.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")
# Main UI
st.title("Chat Conversation UI")
# Session states
if "model" not in st.session_state:
st.session_state["model"] = None
if "tokenizer" not in st.session_state:
st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
st.session_state["conversation"] = []
if "user_input" not in st.session_state:
st.session_state["user_input"] = ""
# Load Model
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            # device 0 = first GPU for the pipeline, -1 = CPU
            device = 0 if torch.cuda.is_available() else -1
            st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir="./model_cache")
            st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_name, cache_dir="./model_cache")
            st.session_state["qa_pipeline"] = pipeline(
                "text2text-generation",
                model=st.session_state["model"],
                tokenizer=st.session_state["tokenizer"],
                device=device,
            )
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            st.error(f"Error loading model: {e}")
# Clear Model
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")
# Layout for chat
chat_container = st.container()
input_container = st.container()
# Chat Conversation Display
with chat_container:
    st.subheader("Conversation")
    for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
        # st.write takes no key/disabled arguments; plain writes suffice here
        if speaker == "You":
            st.write(f"You ({idx}):", message)
        else:
            st.write(f"Model ({idx}):", message)
# Input Area
with input_container:
    if st.session_state["qa_pipeline"]:
        user_input = st.text_input(
            "Enter your query:",
            value=st.session_state["user_input"],  # session state keeps the draft across reruns
            key="chat_input",
            label_visibility="visible",
            on_change=lambda: st.session_state.update({"user_input": st.session_state.chat_input}),
        )
        if st.button("Send", key="send_button"):
            if st.session_state["user_input"]:
                generated_text = None
                with st.spinner("Generating response..."):
                    try:
                        response = st.session_state["qa_pipeline"](f"Q: {st.session_state['user_input']}", max_length=400)
                        generated_text = response[0]["generated_text"]
                    except Exception as e:
                        st.error(f"Error generating response: {e}")
                if generated_text is not None:
                    st.session_state["conversation"].append(("You", st.session_state["user_input"]))
                    st.session_state["conversation"].append(("Model", generated_text))
                    st.session_state["user_input"] = ""  # clear the stored draft after submission
                    # Rerun so the messages appended above show up in the conversation
                    # panel, which was already rendered earlier in this script run
                    # (st.rerun needs Streamlit >= 1.27; older versions use st.experimental_rerun).
                    st.rerun()
# Clear Conversation
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")