File size: 5,272 Bytes
62098f3
d7d1a81
 
 
 
 
 
 
62098f3
 
d7d1a81
f0b46d5
62098f3
d7d1a81
62098f3
c2d1e89
5f762bd
 
c2d1e89
 
d7d1a81
 
 
 
62098f3
 
 
 
d7d1a81
62098f3
 
 
 
 
 
 
 
 
d7d1a81
be79c33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62098f3
be79c33
62098f3
d7d1a81
62098f3
 
 
 
 
 
d7d1a81
 
 
 
f0b46d5
402b346
eeaf9ed
d7d1a81
be79c33
 
 
 
 
 
 
 
 
 
 
eeaf9ed
be79c33
eeaf9ed
 
 
be79c33
eeaf9ed
be79c33
 
 
 
 
 
eeaf9ed
 
be79c33
 
eeaf9ed
be79c33
 
 
 
 
402b346
be79c33
402b346
be79c33
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import streamlit as st
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
import torch

# ----- Streamlit page config -----
st.set_page_config(page_title="Chat", layout="wide")

# ----- Sidebar: Model controls -----
st.sidebar.title("Model Controls")
model_options = {
    "1": "karthikeyan-r/calculation_model_11k",
    "2": "karthikeyan-r/slm-custom-model_6k"
}

model_choice = st.sidebar.selectbox(
    "Select Model",
    options=list(model_options.values())
)
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# ----- Session States -----
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []

# ----- Load Model -----
def load_model():
    if st.session_state["model"] is None or st.session_state["tokenizer"] is None:
        with st.spinner("Loading model..."):
            try:
                if model_choice == model_options["1"]:
                    # Load the calculation model
                    tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                    model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")

                    # Add special tokens if needed
                    if tokenizer.pad_token is None:
                        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                        model.resize_token_embeddings(len(tokenizer))
                    if tokenizer.eos_token is None:
                        tokenizer.add_special_tokens({'eos_token': '[EOS]'})
                        model.resize_token_embeddings(len(tokenizer))

                    model.config.pad_token_id = tokenizer.pad_token_id
                    model.config.eos_token_id = tokenizer.eos_token_id

                    st.session_state["model"] = model
                    st.session_state["tokenizer"] = tokenizer
                    st.session_state["qa_pipeline"] = None  # Not needed for calculation model

                elif model_choice == model_options["2"]:
                    # Load the T5 model for general QA
                    device = 0 if torch.cuda.is_available() else -1
                    model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
                    tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                    qa_pipe = pipeline(
                        "text2text-generation",
                        model=model,
                        tokenizer=tokenizer,
                        device=device
                    )
                    st.session_state["model"] = model
                    st.session_state["tokenizer"] = tokenizer
                    st.session_state["qa_pipeline"] = qa_pipe

                st.success("Model loaded successfully and ready!")
            except Exception as e:
                st.error(f"Error loading model: {e}")

if load_model_button:
    load_model()

# ----- Clear Model -----
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")

# ----- Clear Conversation -----
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")

# ----- Title -----
st.title("Chat Conversation UI")

# ----- User Input and Processing -----
user_input = st.chat_input("Enter your query:")
if user_input:
    # Save user input
    st.session_state["conversation"].append({
        "role": "user",
        "content": user_input
    })

    # Generate response
    if st.session_state["qa_pipeline"]:
        try:
            response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=250)
            answer = response[0]["generated_text"]
        except Exception as e:
            answer = f"Error: {str(e)}"
    elif st.session_state["model"] and model_choice == model_options["1"]:
        try:
            tokenizer = st.session_state["tokenizer"]
            model = st.session_state["model"]

            inputs = tokenizer(f"Input: {user_input}\nOutput:", return_tensors="pt", padding=True, truncation=True)
            output = model.generate(inputs.input_ids, max_length=250, pad_token_id=tokenizer.pad_token_id)
            answer = tokenizer.decode(output[0], skip_special_tokens=True).split("Output:")[-1].strip()
        except Exception as e:
            answer = f"Error: {str(e)}"
    else:
        answer = "No model is loaded. Please select and load a model."

    # Save assistant response
    st.session_state["conversation"].append({
        "role": "assistant",
        "content": answer
    })

# Display conversation
for message in st.session_state["conversation"]:
    with st.chat_message(message["role"]):
        st.write(message["content"])