File size: 3,741 Bytes
62098f3
 
 
 
409d83c
62098f3
 
 
 
 
 
 
 
 
 
c4c458c
62098f3
409d83c
62098f3
 
 
 
 
 
 
 
d371388
 
62098f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4c458c
 
 
62098f3
c4c458c
 
 
11d5f76
62098f3
a546bac
62098f3
a546bac
62098f3
c4c458c
 
 
d371388
 
 
 
 
 
 
c4c458c
d371388
c4c458c
 
e559684
c4c458c
d371388
c4c458c
d371388
c4c458c
 
 
62098f3
 
 
e559684
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch

# Streamlit app setup: page configuration is kept as the first Streamlit call.
st.set_page_config(layout="wide", page_title="Hugging Face Chat")

# Sidebar: text box for the model id plus the three action buttons.
# Each button variable is True only on the rerun triggered by its click.
sidebar = st.sidebar
sidebar.title("Model Controls")
model_name = sidebar.text_input(
    "Enter Model Name",
    value="karthikeyan-r/slm-custom-model_6k",
)
load_model_button = sidebar.button("Load Model")
clear_conversation_button = sidebar.button("Clear Conversation")
clear_model_button = sidebar.button("Clear Model")

# Main UI
st.title("Chat Conversation UI")

# Session states: seed each key only when it is missing, so values stored on
# earlier reruns of this session survive.
for state_key, initial_value in (
    ("model", None),
    ("tokenizer", None),
    ("qa_pipeline", None),
    ("conversation", []),  # list of (speaker, message) tuples
    ("user_input", ""),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Load Model: build the model, tokenizer and generation pipeline for the id
# typed in the sidebar, and store them in session state only if all succeed.
if load_model_button:
    with st.spinner("Loading model..."):
        # Whitespace around a pasted repo id is not part of a valid id.
        repo_id = model_name.strip()
        try:
            # pipeline() device index: 0 selects the first CUDA GPU, -1 the CPU.
            device = 0 if torch.cuda.is_available() else -1
            model = T5ForConditionalGeneration.from_pretrained(repo_id, cache_dir="./model_cache")
            tokenizer = T5Tokenizer.from_pretrained(repo_id, cache_dir="./model_cache")
            qa_pipeline = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                device=device,
            )
        except Exception as e:
            # Any previously loaded model/tokenizer/pipeline is left untouched.
            st.error(f"Error loading model: {e}")
        else:
            # Commit all three together so a partial failure can never leave a
            # new model paired with an old tokenizer or a stale pipeline.
            st.session_state["model"] = model
            st.session_state["tokenizer"] = tokenizer
            st.session_state["qa_pipeline"] = qa_pipeline
            st.success("Model loaded successfully and ready!")

# Clear Model: drop the session's references to the model, tokenizer and
# pipeline (the chat input is only offered while "qa_pipeline" is set).
if clear_model_button:
    for state_key in ("model", "tokenizer", "qa_pipeline"):
        st.session_state[state_key] = None
    st.success("Model cleared.")

# Layout for chat. Both containers are created up front, so the conversation
# always sits above the input box on the page. The conversation is filled in
# last, after this run's Clear/Send actions have been applied, so the page shows
# the updated history immediately rather than one rerun late.
chat_container = st.container()
input_container = st.container()


def _queue_query():
    """Move the typed query into ``user_input`` and empty the text box.

    Used as the Send button's ``on_click`` callback. Callbacks run before the
    script reruns, which is when a keyed widget's value may still be
    reassigned through ``st.session_state``. Assigning ``chat_input`` after
    the widget has been created in a run is rejected by Streamlit.
    """
    st.session_state["user_input"] = st.session_state.get("chat_input", "")
    st.session_state["chat_input"] = ""


# Clear Conversation: applied before the history is drawn so the cleared state
# is visible on this same run. The message renders below both containers.
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")

# Input Area
with input_container:
    if st.session_state["qa_pipeline"]:
        # The widget owns its text via the "chat_input" key. No value= is
        # passed, so the callback above can clear it without conflicts.
        st.text_input(
            "Enter your query:",
            key="chat_input",
            label_visibility="visible",
        )
        st.button("Send", key="send_button", on_click=_queue_query)

        # Non-empty only on the rerun triggered by Send (set by _queue_query).
        # Consume it right away so the query is processed exactly once.
        query = st.session_state["user_input"]
        st.session_state["user_input"] = ""
        if query.strip():
            with st.spinner("Generating response..."):
                try:
                    response = st.session_state["qa_pipeline"](f"Q: {query}", max_length=400)
                    generated_text = response[0]["generated_text"]
                except Exception as e:
                    st.error(f"Error generating response: {e}")
                else:
                    st.session_state["conversation"].append(("You", query))
                    st.session_state["conversation"].append(("Model", generated_text))

# Chat Conversation Display: rendered last, into the container created above.
with chat_container:
    st.subheader("Conversation")
    for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
        label = "You" if speaker == "You" else "Model"
        # st.write is not a widget: it takes no key/disabled options.
        st.write(f"{label} ({idx}):", message)