Daemontatox committed
Commit c2c5723 · verified · 1 Parent(s): b10a3b2

Update app.py

Files changed (1): app.py (+92, -317)

app.py CHANGED
@@ -1,334 +1,109 @@
-import subprocess
-
-subprocess.run(
-    'pip install flash-attn --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
-import os
-import re
-import time
-import torch
-import spaces
-import gradio as gr
-from threading import Thread
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    TextIteratorStreamer
-)
-
-# Configuration Constants
-MODEL_ID = "CohereForAI/aya-expanse-32b"
-DEFAULT_SYSTEM_PROMPT = """You are a highly intelligent Bilingual assistant who is fluent in Arabic and English."""
-# UI Configuration
-TITLE = "<h1><center>Mawared T Assistant</center></h1>"
-PLACEHOLDER = "Ask me anything! I'll think through it step by step."
-
-CSS = """
-.duplicate-button {
-    margin: auto !important;
-    color: white !important;
-    background: black !important;
-    border-radius: 100vh !important;
-}
-h3 {
-    text-align: center;
-}
-.message-wrap {
-    overflow-x: auto;
-}
-.message-wrap p {
-    margin-bottom: 1em;
-}
-.message-wrap pre {
-    background-color: #f6f8fa;
-    border-radius: 3px;
-    padding: 16px;
-    overflow-x: auto;
-}
-.message-wrap code {
-    background-color: rgba(175,184,193,0.2);
-    border-radius: 3px;
-    padding: 0.2em 0.4em;
-    font-family: monospace;
-}
-.custom-tag {
-    color: #0066cc;
-    font-weight: bold;
-}
-.chat-area {
-    height: 500px !important;
-    overflow-y: auto !important;
-}
-"""
-
-def initialize_model():
-    """Initialize the model with appropriate configurations"""
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
-        bnb_4bit_use_double_quant=True
-    )
-
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        torch_dtype=torch.float16,
-        device_map="cuda",
-        attn_implementation="flash_attention_2",
-        quantization_config=quantization_config
-    )
-
-    return model, tokenizer
-
-def format_text(text):
-    """Format text with proper spacing and tag highlighting (but keep tags visible)"""
-    tag_patterns = [
-        (r'<Thinking>', '\n<Thinking>\n'),
-        (r'</Thinking>', '\n</Thinking>\n'),
-        (r'<Critique>', '\n<Critique>\n'),
-        (r'</Critique>', '\n</Critique>\n'),
-        (r'<Revising>', '\n<Revising>\n'),
-        (r'</Revising>', '\n</Revising>\n'),
-        (r'<Final>', '\n<Final>\n'),
-        (r'</Final>', '\n</Final>\n')
-    ]
-
-    formatted = text
-    for pattern, replacement in tag_patterns:
-        formatted = re.sub(pattern, replacement, formatted)
-
-    formatted = '\n'.join(line for line in formatted.split('\n') if line.strip())
-
-    return formatted
-
-def format_chat_history(history):
-    """Format chat history for display, keeping tags visible"""
-    formatted = []
-    for user_msg, assistant_msg in history:
-        formatted.append(f"User: {user_msg}")
-        if assistant_msg:
-            formatted.append(f"Assistant: {assistant_msg}")
-    return "\n\n".join(formatted)
-
-def create_examples():
-    """Create example queries for the UI"""
-    return [
-        "Explain the concept of artificial intelligence.",
-        "How does photosynthesis work?",
-        "What are the main causes of climate change?",
-        "Describe the process of protein synthesis.",
-        "What are the key features of a democratic government?",
-        "Explain the theory of relativity.",
-        "How do vaccines work to prevent diseases?",
-        "What are the major events of World War II?",
-        "Describe the structure of a human cell.",
-        "What is the role of DNA in genetics?"
-    ]
-
-@spaces.GPU(duration=660)
-def chat_response(
-    message: str,
-    history: list,
-    chat_display: str,
-    system_prompt: str,
-    temperature: float = 1.0,
-    max_new_tokens: int = 4000,
-    top_p: float = 0.8,
-    top_k: int = 40,
-    penalty: float = 1.2,
-):
-    """Generate chat responses, keeping tags visible in the output"""
-    conversation = [
-        {"role": "system", "content": system_prompt}
-    ]
-
-    for prompt, answer in history:
-        conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer}
-        ])
-
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(
-        conversation,
-        add_generation_prompt=True,
-        return_tensors="pt"
-    ).to(model.device)
-
-    streamer = TextIteratorStreamer(
-        tokenizer,
-        timeout=60.0,
-        skip_prompt=True,
-        skip_special_tokens=True
-    )
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=False if temperature == 0 else True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        repetition_penalty=penalty,
-        streamer=streamer,
-    )
-
-    buffer = ""
-
-    with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-        history = history + [[message, ""]]
-
-        for new_text in streamer:
-            buffer += new_text
-            formatted_buffer = format_text(buffer)
-            history[-1][1] = formatted_buffer
-            chat_display = format_chat_history(history)
-
-            yield history, chat_display
-
-def process_example(example: str) -> tuple:
-    """Process example query and return empty history and updated display"""
-    return [], f"User: {example}\n\n"
-
-def main():
-    """Main function to set up and launch the Gradio interface"""
-    global model, tokenizer
-    model, tokenizer = initialize_model()
-
-    with gr.Blocks(css=CSS, theme="soft") as demo:
-        gr.HTML(TITLE)
-        gr.DuplicateButton(
-            value="Duplicate Space for private use",
-            elem_classes="duplicate-button"
-        )
-
-        with gr.Row():
-            with gr.Column():
-                chat_history = gr.State([])
-                chat_display = gr.TextArea(
-                    value="",
-                    label="Chat History",
-                    interactive=False,
-                    elem_classes=["chat-area"],
-                )
-
-                message = gr.TextArea(
-                    placeholder=PLACEHOLDER,
-                    label="Your message",
-                    lines=3
-                )
-
-                with gr.Row():
-                    submit = gr.Button("Send")
-                    clear = gr.Button("Clear")
-
-                with gr.Accordion("⚙️ Advanced Settings", open=False):
-                    system_prompt = gr.TextArea(
-                        value=DEFAULT_SYSTEM_PROMPT,
-                        label="System Prompt",
-                        lines=5,
-                    )
-                    temperature = gr.Slider(
-                        minimum=0,
-                        maximum=1,
-                        step=0.1,
-                        value=0.2,
-                        label="Temperature",
-                    )
-                    max_tokens = gr.Slider(
-                        minimum=128,
-                        maximum=32000,
-                        step=128,
-                        value=4000,
-                        label="Max Tokens",
-                    )
-                    top_p = gr.Slider(
-                        minimum=0.1,
-                        maximum=1.0,
-                        step=0.1,
-                        value=0.8,
-                        label="Top-p",
-                    )
-                    top_k = gr.Slider(
-                        minimum=1,
-                        maximum=100,
-                        step=1,
-                        value=40,
-                        label="Top-k",
-                    )
-                    penalty = gr.Slider(
-                        minimum=1.0,
-                        maximum=2.0,
-                        step=0.1,
-                        value=1.2,
-                        label="Repetition Penalty",
-                    )
-
-                examples = gr.Examples(
-                    examples=create_examples(),
-                    inputs=[message],
-                    outputs=[chat_history, chat_display],
-                    fn=process_example,
-                    cache_examples=False,
-                )
-
-        # Set up event handlers
-        submit_click = submit.click(
-            chat_response,
-            inputs=[
-                message,
-                chat_history,
-                chat_display,
-                system_prompt,
-                temperature,
-                max_tokens,
-                top_p,
-                top_k,
-                penalty,
-            ],
-            outputs=[chat_history, chat_display],
-            show_progress=True,
-        )
-
-        message.submit(
-            chat_response,
-            inputs=[
-                message,
-                chat_history,
-                chat_display,
-                system_prompt,
-                temperature,
-                max_tokens,
-                top_p,
-                top_k,
-                penalty,
-            ],
-            outputs=[chat_history, chat_display],
-            show_progress=True,
-        )

-        clear.click(
-            lambda: ([], ""),
-            outputs=[chat_history, chat_display],
-            show_progress=True,
-        )
-
-        submit_click.then(lambda: "", outputs=message)
-        message.submit(lambda: "", outputs=message)
-
-    return demo
-
-if __name__ == "__main__":
-    demo = main()
-    demo.launch()
+from langchain_groq import ChatGroq
+from langchain_huggingface import HuggingFaceEmbeddings
+import os
+from dotenv import load_dotenv
+from langchain.prompts import ChatPromptTemplate
+from langchain.schema.runnable import RunnablePassthrough
+from langchain.schema.output_parser import StrOutputParser
+from qdrant_client import QdrantClient, models
+from langchain_qdrant import Qdrant
+import gradio as gr
+
+# Load environment variables
+load_dotenv()
+
+os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API")
+
+# HuggingFace Embeddings
+embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
+
+# Qdrant Client Setup
+client = QdrantClient(
+    url=os.getenv("QDRANT_URL"),
+    api_key=os.getenv("QDRANT_API_KEY"),
+    prefer_grpc=True
+)
+
+collection_name = "mawared"
+
+# Try to create the collection; continue if it already exists
+try:
+    client.create_collection(
+        collection_name=collection_name,
+        vectors_config=models.VectorParams(
+            size=1024,  # bge-large-en-v1.5 produces 1024-dimensional embeddings
+            distance=models.Distance.COSINE
+        ),
+    )
+    print(f"Created new collection: {collection_name}")
+except Exception as e:
+    if "already exists" in str(e):
+        print(f"Collection {collection_name} already exists, continuing...")
+    else:
+        raise
+
+# Create Qdrant vector store
+db = Qdrant(
+    client=client,
+    collection_name=collection_name,
+    embeddings=embeddings,
+)
+
+# Create retriever
+retriever = db.as_retriever(
+    search_type="similarity",
+    search_kwargs={"k": 5}
+)
+
+# LLM setup
+llm = ChatGroq(
+    model="llama-3.3-70b-versatile",
+    temperature=0.1,
+    max_tokens=None,
+    timeout=None,
+    max_retries=2,
+)
+
+# Create prompt template
+template = """
+You are an expert assistant specializing in long chain-of-thought (CoT) RAG. Your task is to answer the user's question strictly based on the provided context...
+Context:
+{context}
+
+Question:
+{question}
+
+Answer:
+"""
+
+prompt = ChatPromptTemplate.from_template(template)
+
+# Create the RAG chain
+rag_chain = (
+    {"context": retriever, "question": RunnablePassthrough()}
+    | prompt
+    | llm
+    | StrOutputParser()
+)
+
+# Define the Gradio function
+def ask_question_gradio(question):
+    result = ""
+    for chunk in rag_chain.stream(question):
+        result += chunk
+    return result
+
+# Create the Gradio interface
+interface = gr.Interface(
+    fn=ask_question_gradio,
+    inputs="text",
+    outputs="text",
+    title="Mawared Expert Assistant",
+    description="Ask questions about the Mawared HR System or any related topic using Chain-of-Thought (CoT) and RAG principles.",
+    theme="compact",
+)
+
+# Launch Gradio app
+if __name__ == "__main__":
+    interface.launch()
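
The new version creates the "mawared" collection and wires a similarity retriever over it, but nothing in this commit populates the collection; the retriever returns empty context until vectors are upserted. A minimal ingestion sketch, assuming the source documents arrive as plain strings (the ingest helper and splitter settings below are illustrative, not part of the app):

    # Hypothetical ingestion helper: the real document source is not in this commit.
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    def ingest(texts):
        # Split long documents into overlapping chunks before embedding
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_text("\n\n".join(texts))
        # add_texts embeds each chunk with the configured bge-large model
        # and upserts the resulting vectors into the "mawared" collection
        db.add_texts(chunks)

Running this once before serving queries is enough; the retriever and chain need no changes.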
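Note also that ask_question_gradio consumes rag_chain.stream() but buffers the entire answer before returning, so the UI only updates once generation finishes. Gradio treats generator functions as streaming endpoints, so a yield-based sketch of the same function (only the body changes) would render the answer incrementally:

    def ask_question_gradio(question):
        partial = ""
        for chunk in rag_chain.stream(question):
            partial += chunk
            yield partial  # each yield re-renders the output textbox

Passing this generator as fn to the same gr.Interface requires no other wiring changes.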