prakharmishra2002 commited on
Commit
0d89bdf
·
verified ·
1 Parent(s): 4ae2bcd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import uuid
5
+ import time
6
+
7
+ # Page configuration
8
+ st.set_page_config(
9
+ page_title="ChatBot",
10
+ page_icon="💬",
11
+ layout="wide",
12
+ initial_sidebar_state="expanded"
13
+ )
14
+
15
+ # Initialize session state variables
16
+ if "chat_history" not in st.session_state:
17
+ st.session_state.chat_history = {}
18
+ if "current_chat_id" not in st.session_state:
19
+ st.session_state.current_chat_id = None
20
+ if "messages" not in st.session_state:
21
+ st.session_state.messages = []
22
+
23
+ # Load model and tokenizer
24
+ @st.cache_resource
25
+ def load_model():
26
+ model_name = "facebook/blenderbot-400M-distill"
27
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
28
+ model = AutoModelForCausalLM.from_pretrained(model_name)
29
+ return tokenizer, model
30
+
31
+ tokenizer, model = load_model()
32
+
33
+ # Function to generate response
34
+ def generate_response(prompt):
35
+ inputs = tokenizer(prompt, return_tensors="pt")
36
+ with torch.no_grad():
37
+ outputs = model.generate(
38
+ inputs.input_ids,
39
+ max_length=100,
40
+ num_return_sequences=1,
41
+ temperature=0.7,
42
+ top_p=0.9,
43
+ do_sample=True
44
+ )
45
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
46
+ return response
47
+
48
+ # Custom CSS
49
+ st.markdown("""
50
+ <style>
51
+ .main {
52
+ background-color: #f9f9f9;
53
+ }
54
+ .stTextInput>div>div>input {
55
+ background-color: white;
56
+ }
57
+ .chat-message {
58
+ padding: 1rem;
59
+ border-radius: 0.5rem;
60
+ margin-bottom: 1rem;
61
+ display: flex;
62
+ flex-direction: row;
63
+ align-items: flex-start;
64
+ }
65
+ .chat-message.user {
66
+ background-color: #f0f0f0;
67
+ }
68
+ .chat-message.bot {
69
+ background-color: #e6f7ff;
70
+ }
71
+ .chat-message .avatar {
72
+ width: 40px;
73
+ height: 40px;
74
+ border-radius: 50%;
75
+ object-fit: cover;
76
+ margin-right: 1rem;
77
+ }
78
+ .chat-message .message {
79
+ flex-grow: 1;
80
+ }
81
+ .sidebar-chat {
82
+ padding: 0.5rem;
83
+ border-radius: 0.5rem;
84
+ margin-bottom: 0.5rem;
85
+ cursor: pointer;
86
+ }
87
+ .sidebar-chat:hover {
88
+ background-color: #f0f0f0;
89
+ }
90
+ .sidebar-chat.active {
91
+ background-color: #e6f7ff;
92
+ font-weight: bold;
93
+ }
94
+ .stButton>button {
95
+ width: 100%;
96
+ }
97
+ </style>
98
+ """, unsafe_allow_html=True)
99
+
100
+ # Sidebar for chat history
101
+ with st.sidebar:
102
+ st.title("💬 Chats")
103
+
104
+ # New chat button
105
+ if st.button("+ New Chat"):
106
+ # Generate a new chat ID
107
+ new_chat_id = str(uuid.uuid4())
108
+ st.session_state.current_chat_id = new_chat_id
109
+ st.session_state.chat_history[new_chat_id] = {
110
+ "title": f"Chat {len(st.session_state.chat_history) + 1}",
111
+ "messages": []
112
+ }
113
+ st.session_state.messages = []
114
+ st.rerun()
115
+
116
+ st.markdown("---")
117
+
118
+ # Display chat history
119
+ for chat_id, chat_data in st.session_state.chat_history.items():
120
+ chat_class = "active" if chat_id == st.session_state.current_chat_id else ""
121
+ if st.sidebar.markdown(f"""
122
+ <div class="sidebar-chat {chat_class}" id="{chat_id}">
123
+ {chat_data["title"]}
124
+ </div>
125
+ """, unsafe_allow_html=True):
126
+ st.session_state.current_chat_id = chat_id
127
+ st.session_state.messages = chat_data["messages"]
128
+ st.rerun()
129
+
130
+ # Main chat interface
131
+ st.title("ChatBot")
132
+
133
+ # Initialize a new chat if none exists
134
+ if not st.session_state.current_chat_id and not st.session_state.chat_history:
135
+ new_chat_id = str(uuid.uuid4())
136
+ st.session_state.current_chat_id = new_chat_id
137
+ st.session_state.chat_history[new_chat_id] = {
138
+ "title": "New Chat",
139
+ "messages": []
140
+ }
141
+
142
+ # Display chat messages
143
+ if st.session_state.current_chat_id:
144
+ for i, message in enumerate(st.session_state.messages):
145
+ if message["role"] == "user":
146
+ st.markdown(f"""
147
+ <div class="chat-message user">
148
+ <img class="avatar" src="https://api.dicebear.com/7.x/bottts/svg?seed=user" alt="User Avatar">
149
+ <div class="message">{message["content"]}</div>
150
+ </div>
151
+ """, unsafe_allow_html=True)
152
+ else:
153
+ st.markdown(f"""
154
+ <div class="chat-message bot">
155
+ <img class="avatar" src="https://api.dicebear.com/7.x/bottts/svg?seed=bot" alt="Bot Avatar">
156
+ <div class="message">{message["content"]}</div>
157
+ </div>
158
+ """, unsafe_allow_html=True)
159
+
160
+ # Chat input
161
+ if prompt := st.chat_input("Type your message here..."):
162
+ if st.session_state.current_chat_id:
163
+ # Add user message to chat
164
+ st.session_state.messages.append({"role": "user", "content": prompt})
165
+
166
+ # Update chat history
167
+ st.session_state.chat_history[st.session_state.current_chat_id]["messages"] = st.session_state.messages
168
+
169
+ # Update chat title if it's the first message
170
+ if len(st.session_state.messages) == 1:
171
+ st.session_state.chat_history[st.session_state.current_chat_id]["title"] = prompt[:20] + "..." if len(prompt) > 20 else prompt
172
+
173
+ st.rerun()
174
+
175
+ # Generate and display bot response for the last user message
176
+ if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
177
+ with st.spinner("Thinking..."):
178
+ # Simulate thinking time
179
+ time.sleep(0.5)
180
+
181
+ # Generate response
182
+ response = generate_response(st.session_state.messages[-1]["content"])
183
+
184
+ # Add bot response to chat
185
+ st.session_state.messages.append({"role": "assistant", "content": response})
186
+
187
+ # Update chat history
188
+ st.session_state.chat_history[st.session_state.current_chat_id]["messages"] = st.session_state.messages
189
+
190
+ st.rerun()
191
+