Adjoumani commited on
Commit
bdc200f
·
verified ·
1 Parent(s): b5f2809

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +529 -0
app.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import fitz
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ import requests
7
+ from llama_index.llms.nvidia import NVIDIA
8
+ import streamlit as st
9
+ from llama_index.core import Settings
10
+ from llama_index.core import VectorStoreIndex, StorageContext
11
+ from llama_index.core.node_parser import SentenceSplitter
12
+ from llama_index.vector_stores.milvus import MilvusVectorStore
13
+ from llama_index.embeddings.nvidia import NVIDIAEmbedding
14
+
15
+ from pptx import Presentation
16
+ import subprocess
17
+ from llama_index.core import Document
18
+
19
+
20
+
21
+ def set_environment_variables():
22
+ """Set necessary environment variables."""
23
+ os.environ["NVIDIA_API_KEY"] = "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" #set API key
24
+
25
+ def get_b64_image_from_content(image_content):
26
+ """Convert image content to base64 encoded string."""
27
+ img = Image.open(BytesIO(image_content))
28
+ if img.mode != 'RGB':
29
+ img = img.convert('RGB')
30
+ buffered = BytesIO()
31
+ img.save(buffered, format="JPEG")
32
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
33
+
34
+ def is_graph(image_content):
35
+ """Determine if an image is a graph, plot, chart, or table."""
36
+ res = describe_image(image_content)
37
+ return any(keyword in res.lower() for keyword in ["graph", "plot", "chart", "table"])
38
+
39
+ def process_graph(image_content):
40
+ """Process a graph image and generate a description."""
41
+ deplot_description = process_graph_deplot(image_content)
42
+ mixtral = NVIDIA(model_name="meta/llama-3.1-70b-instruct")
43
+ response = mixtral.complete("Your responsibility is to explain charts. You are an expert in describing the responses of linearized tables into plain English text for LLMs to use. Explain the following linearized table. " + deplot_description)
44
+ return response.text
45
+
46
+ def describe_image(image_content):
47
+ """Generate a description of an image using NVIDIA API."""
48
+ image_b64 = get_b64_image_from_content(image_content)
49
+ invoke_url = "https://ai.api.nvidia.com/v1/vlm/nvidia/neva-22b"
50
+ api_key = os.getenv("NVIDIA_API_KEY")
51
+
52
+ if not api_key:
53
+ raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
54
+
55
+ headers = {
56
+ "Authorization": f"Bearer {api_key}",
57
+ "Accept": "application/json"
58
+ }
59
+
60
+ payload = {
61
+ "messages": [
62
+ {
63
+ "role": "user",
64
+ "content": f'Describe what you see in this image. <img src="data:image/png;base64,{image_b64}" />'
65
+ }
66
+ ],
67
+ "max_tokens": 1024,
68
+ "temperature": 0.20,
69
+ "top_p": 0.70,
70
+ "seed": 0,
71
+ "stream": False
72
+ }
73
+
74
+ response = requests.post(invoke_url, headers=headers, json=payload)
75
+ return response.json()["choices"][0]['message']['content']
76
+
77
+ def process_graph_deplot(image_content):
78
+ """Process a graph image using NVIDIA's Deplot API."""
79
+ invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot"
80
+ image_b64 = get_b64_image_from_content(image_content)
81
+ api_key = os.getenv("NVIDIA_API_KEY")
82
+
83
+ if not api_key:
84
+ raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
85
+
86
+ headers = {
87
+ "Authorization": f"Bearer {api_key}",
88
+ "Accept": "application/json"
89
+ }
90
+
91
+ payload = {
92
+ "messages": [
93
+ {
94
+ "role": "user",
95
+ "content": f'Generate underlying data table of the figure below: <img src="data:image/png;base64,{image_b64}" />'
96
+ }
97
+ ],
98
+ "max_tokens": 1024,
99
+ "temperature": 0.20,
100
+ "top_p": 0.20,
101
+ "stream": False
102
+ }
103
+
104
+ response = requests.post(invoke_url, headers=headers, json=payload)
105
+ return response.json()["choices"][0]['message']['content']
106
+
107
+ def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1):
108
+ """Extract text above and below a given bounding box on a page."""
109
+ before_text, after_text = "", ""
110
+ vertical_threshold_distance = page_height * threshold_percentage
111
+ horizontal_threshold_distance = bbox.width * threshold_percentage
112
+
113
+ for block in text_blocks:
114
+ block_bbox = fitz.Rect(block[:4])
115
+ vertical_distance = min(abs(block_bbox.y1 - bbox.y0), abs(block_bbox.y0 - bbox.y1))
116
+ horizontal_overlap = max(0, min(block_bbox.x1, bbox.x1) - max(block_bbox.x0, bbox.x0))
117
+
118
+ if vertical_distance <= vertical_threshold_distance and horizontal_overlap >= -horizontal_threshold_distance:
119
+ if block_bbox.y1 < bbox.y0 and not before_text:
120
+ before_text = block[4]
121
+ elif block_bbox.y0 > bbox.y1 and not after_text:
122
+ after_text = block[4]
123
+ break
124
+
125
+ return before_text, after_text
126
+
127
+ def process_text_blocks(text_blocks, char_count_threshold=500):
128
+ """Group text blocks based on a character count threshold."""
129
+ current_group = []
130
+ grouped_blocks = []
131
+ current_char_count = 0
132
+
133
+ for block in text_blocks:
134
+ if block[-1] == 0: # Check if the block is of text type
135
+ block_text = block[4]
136
+ block_char_count = len(block_text)
137
+
138
+ if current_char_count + block_char_count <= char_count_threshold:
139
+ current_group.append(block)
140
+ current_char_count += block_char_count
141
+ else:
142
+ if current_group:
143
+ grouped_content = "\n".join([b[4] for b in current_group])
144
+ grouped_blocks.append((current_group[0], grouped_content))
145
+ current_group = [block]
146
+ current_char_count = block_char_count
147
+
148
+ # Append the last group
149
+ if current_group:
150
+ grouped_content = "\n".join([b[4] for b in current_group])
151
+ grouped_blocks.append((current_group[0], grouped_content))
152
+
153
+ return grouped_blocks
154
+
155
+ def save_uploaded_file(uploaded_file):
156
+ """Save an uploaded file to a temporary directory."""
157
+ temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp")
158
+ os.makedirs(temp_dir, exist_ok=True)
159
+ temp_file_path = os.path.join(temp_dir, uploaded_file.name)
160
+
161
+ with open(temp_file_path, "wb") as temp_file:
162
+ temp_file.write(uploaded_file.read())
163
+
164
+ return temp_file_path
165
+
166
+
167
+
168
+ # 2ème fichier du code
169
+
170
+
171
+
172
+
173
+ def get_pdf_documents(pdf_file):
174
+ """Process a PDF file and extract text, tables, and images."""
175
+ all_pdf_documents = []
176
+ ongoing_tables = {}
177
+
178
+ try:
179
+ f = fitz.open(stream=pdf_file.read(), filetype="pdf")
180
+ except Exception as e:
181
+ print(f"Error opening or processing the PDF file: {e}")
182
+ return []
183
+
184
+ for i in range(len(f)):
185
+ page = f[i]
186
+ text_blocks = [block for block in page.get_text("blocks", sort=True)
187
+ if block[-1] == 0 and not (block[1] < page.rect.height * 0.1 or block[3] > page.rect.height * 0.9)]
188
+ grouped_text_blocks = process_text_blocks(text_blocks)
189
+
190
+ table_docs, table_bboxes, ongoing_tables = parse_all_tables(pdf_file.name, page, i, text_blocks, ongoing_tables)
191
+ all_pdf_documents.extend(table_docs)
192
+
193
+ image_docs = parse_all_images(pdf_file.name, page, i, text_blocks)
194
+ all_pdf_documents.extend(image_docs)
195
+
196
+ for text_block_ctr, (heading_block, content) in enumerate(grouped_text_blocks, 1):
197
+ heading_bbox = fitz.Rect(heading_block[:4])
198
+ if not any(heading_bbox.intersects(table_bbox) for table_bbox in table_bboxes):
199
+ bbox = {"x1": heading_block[0], "y1": heading_block[1], "x2": heading_block[2], "x3": heading_block[3]}
200
+ text_doc = Document(
201
+ text=f"{heading_block[4]}\n{content}",
202
+ metadata={
203
+ **bbox,
204
+ "type": "text",
205
+ "page_num": i,
206
+ "source": f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}"
207
+ },
208
+ id_=f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}"
209
+ )
210
+ all_pdf_documents.append(text_doc)
211
+
212
+ f.close()
213
+ return all_pdf_documents
214
+
215
+ def parse_all_tables(filename, page, pagenum, text_blocks, ongoing_tables):
216
+ """Extract tables from a PDF page."""
217
+ table_docs = []
218
+ table_bboxes = []
219
+ try:
220
+ tables = page.find_tables(horizontal_strategy="lines_strict", vertical_strategy="lines_strict")
221
+ for tab in tables:
222
+ if not tab.header.external:
223
+ pandas_df = tab.to_pandas()
224
+ tablerefdir = os.path.join(os.getcwd(), "vectorstore/table_references")
225
+ os.makedirs(tablerefdir, exist_ok=True)
226
+ df_xlsx_path = os.path.join(tablerefdir, f"table{len(table_docs)+1}-page{pagenum}.xlsx")
227
+ pandas_df.to_excel(df_xlsx_path)
228
+ bbox = fitz.Rect(tab.bbox)
229
+ table_bboxes.append(bbox)
230
+
231
+ before_text, after_text = extract_text_around_item(text_blocks, bbox, page.rect.height)
232
+
233
+ table_img = page.get_pixmap(clip=bbox)
234
+ table_img_path = os.path.join(tablerefdir, f"table{len(table_docs)+1}-page{pagenum}.jpg")
235
+ table_img.save(table_img_path)
236
+ description = process_graph(table_img.tobytes())
237
+
238
+ caption = before_text.replace("\n", " ") + description + after_text.replace("\n", " ")
239
+ if before_text == "" and after_text == "":
240
+ caption = " ".join(tab.header.names)
241
+ table_metadata = {
242
+ "source": f"{filename[:-4]}-page{pagenum}-table{len(table_docs)+1}",
243
+ "dataframe": df_xlsx_path,
244
+ "image": table_img_path,
245
+ "caption": caption,
246
+ "type": "table",
247
+ "page_num": pagenum
248
+ }
249
+ all_cols = ", ".join(list(pandas_df.columns.values))
250
+ doc = Document(text=f"This is a table with the caption: {caption}\nThe columns are {all_cols}", metadata=table_metadata)
251
+ table_docs.append(doc)
252
+ except Exception as e:
253
+ print(f"Error during table extraction: {e}")
254
+ return table_docs, table_bboxes, ongoing_tables
255
+
256
+ def parse_all_images(filename, page, pagenum, text_blocks):
257
+ """Extract images from a PDF page."""
258
+ image_docs = []
259
+ image_info_list = page.get_image_info(xrefs=True)
260
+ page_rect = page.rect
261
+
262
+ for image_info in image_info_list:
263
+ xref = image_info['xref']
264
+ if xref == 0:
265
+ continue
266
+
267
+ img_bbox = fitz.Rect(image_info['bbox'])
268
+ if img_bbox.width < page_rect.width / 20 or img_bbox.height < page_rect.height / 20:
269
+ continue
270
+
271
+ extracted_image = page.parent.extract_image(xref)
272
+ image_data = extracted_image["image"]
273
+ imgrefpath = os.path.join(os.getcwd(), "vectorstore/image_references")
274
+ os.makedirs(imgrefpath, exist_ok=True)
275
+ image_path = os.path.join(imgrefpath, f"image{xref}-page{pagenum}.png")
276
+ with open(image_path, "wb") as img_file:
277
+ img_file.write(image_data)
278
+
279
+ before_text, after_text = extract_text_around_item(text_blocks, img_bbox, page.rect.height)
280
+ if before_text == "" and after_text == "":
281
+ continue
282
+
283
+ image_description = " "
284
+ if is_graph(image_data):
285
+ image_description = process_graph(image_data)
286
+
287
+ caption = before_text.replace("\n", " ") + image_description + after_text.replace("\n", " ")
288
+
289
+ image_metadata = {
290
+ "source": f"{filename[:-4]}-page{pagenum}-image{xref}",
291
+ "image": image_path,
292
+ "caption": caption,
293
+ "type": "image",
294
+ "page_num": pagenum
295
+ }
296
+ image_docs.append(Document(text="This is an image with the caption: " + caption, metadata=image_metadata))
297
+ return image_docs
298
+
299
+ def process_ppt_file(ppt_path):
300
+ """Process a PowerPoint file."""
301
+ pdf_path = convert_ppt_to_pdf(ppt_path)
302
+ images_data = convert_pdf_to_images(pdf_path)
303
+ slide_texts = extract_text_and_notes_from_ppt(ppt_path)
304
+ processed_data = []
305
+
306
+ for (image_path, page_num), (slide_text, notes) in zip(images_data, slide_texts):
307
+ if notes:
308
+ notes = "\n\nThe speaker notes for this slide are: " + notes
309
+
310
+ with open(image_path, 'rb') as image_file:
311
+ image_content = image_file.read()
312
+
313
+ image_description = " "
314
+ if is_graph(image_content):
315
+ image_description = process_graph(image_content)
316
+
317
+ image_metadata = {
318
+ "source": f"{os.path.basename(ppt_path)}",
319
+ "image": image_path,
320
+ "caption": slide_text + image_description + notes,
321
+ "type": "image",
322
+ "page_num": page_num
323
+ }
324
+ processed_data.append(Document(text="This is a slide with the text: " + slide_text + image_description, metadata=image_metadata))
325
+
326
+ return processed_data
327
+
328
+ def convert_ppt_to_pdf(ppt_path):
329
+ """Convert a PowerPoint file to PDF using LibreOffice."""
330
+ base_name = os.path.basename(ppt_path)
331
+ ppt_name_without_ext = os.path.splitext(base_name)[0].replace(' ', '_')
332
+ new_dir_path = os.path.abspath("vectorstore/ppt_references")
333
+ os.makedirs(new_dir_path, exist_ok=True)
334
+ pdf_path = os.path.join(new_dir_path, f"{ppt_name_without_ext}.pdf")
335
+ command = ['libreoffice', '--headless', '--convert-to', 'pdf', '--outdir', new_dir_path, ppt_path]
336
+ subprocess.run(command, check=True)
337
+ return pdf_path
338
+
339
+ def convert_pdf_to_images(pdf_path):
340
+ """Convert a PDF file to a series of images using PyMuPDF."""
341
+ doc = fitz.open(pdf_path)
342
+ base_name = os.path.basename(pdf_path)
343
+ pdf_name_without_ext = os.path.splitext(base_name)[0].replace(' ', '_')
344
+ new_dir_path = os.path.join(os.getcwd(), "vectorstore/ppt_references")
345
+ os.makedirs(new_dir_path, exist_ok=True)
346
+ image_paths = []
347
+
348
+ for page_num in range(len(doc)):
349
+ page = doc.load_page(page_num)
350
+ pix = page.get_pixmap()
351
+ output_image_path = os.path.join(new_dir_path, f"{pdf_name_without_ext}_{page_num:04d}.png")
352
+ pix.save(output_image_path)
353
+ image_paths.append((output_image_path, page_num))
354
+ doc.close()
355
+ return image_paths
356
+
357
+ def extract_text_and_notes_from_ppt(ppt_path):
358
+ """Extract text and notes from a PowerPoint file."""
359
+ prs = Presentation(ppt_path)
360
+ text_and_notes = []
361
+ for slide in prs.slides:
362
+ slide_text = ' '.join([shape.text for shape in slide.shapes if hasattr(shape, "text")])
363
+ try:
364
+ notes = slide.notes_slide.notes_text_frame.text if slide.notes_slide else ''
365
+ except:
366
+ notes = ''
367
+ text_and_notes.append((slide_text, notes))
368
+ return text_and_notes
369
+
370
+ def load_multimodal_data(files):
371
+ """Load and process multiple file types."""
372
+ documents = []
373
+ for file in files:
374
+ file_extension = os.path.splitext(file.name.lower())[1]
375
+ if file_extension in ('.png', '.jpg', '.jpeg'):
376
+ image_content = file.read()
377
+ image_text = describe_image(image_content)
378
+ doc = Document(text=image_text, metadata={"source": file.name, "type": "image"})
379
+ documents.append(doc)
380
+ elif file_extension == '.pdf':
381
+ try:
382
+ pdf_documents = get_pdf_documents(file)
383
+ documents.extend(pdf_documents)
384
+ except Exception as e:
385
+ print(f"Error processing PDF {file.name}: {e}")
386
+ elif file_extension in ('.ppt', '.pptx'):
387
+ try:
388
+ ppt_documents = process_ppt_file(save_uploaded_file(file))
389
+ documents.extend(ppt_documents)
390
+ except Exception as e:
391
+ print(f"Error processing PPT {file.name}: {e}")
392
+ else:
393
+ text = file.read().decode("utf-8")
394
+ doc = Document(text=text, metadata={"source": file.name, "type": "text"})
395
+ documents.append(doc)
396
+ return documents
397
+
398
+ def load_data_from_directory(directory):
399
+ """Load and process multiple file types from a directory."""
400
+ documents = []
401
+ for filename in os.listdir(directory):
402
+ filepath = os.path.join(directory, filename)
403
+ file_extension = os.path.splitext(filename.lower())[1]
404
+ print(filename)
405
+ if file_extension in ('.png', '.jpg', '.jpeg'):
406
+ with open(filepath, "rb") as image_file:
407
+ image_content = image_file.read()
408
+ image_text = describe_image(image_content)
409
+ doc = Document(text=image_text, metadata={"source": filename, "type": "image"})
410
+ print(doc)
411
+ documents.append(doc)
412
+ elif file_extension == '.pdf':
413
+ with open(filepath, "rb") as pdf_file:
414
+ try:
415
+ pdf_documents = get_pdf_documents(pdf_file)
416
+ documents.extend(pdf_documents)
417
+ except Exception as e:
418
+ print(f"Error processing PDF {filename}: {e}")
419
+ elif file_extension in ('.ppt', '.pptx'):
420
+ try:
421
+ ppt_documents = process_ppt_file(filepath)
422
+ documents.extend(ppt_documents)
423
+ print(ppt_documents)
424
+ except Exception as e:
425
+ print(f"Error processing PPT {filename}: {e}")
426
+ else:
427
+ with open(filepath, "r", encoding="utf-8") as text_file:
428
+ text = text_file.read()
429
+ doc = Document(text=text, metadata={"source": filename, "type": "text"})
430
+ documents.append(doc)
431
+ return documents
432
+
433
+
434
+ # 3ème fichier
435
+
436
+
437
+
438
+
439
+ # Set up the page configuration
440
+ st.set_page_config(layout="wide")
441
+
442
+ # Initialize settings
443
+ def initialize_settings():
444
+ Settings.embed_model = NVIDIAEmbedding(model="nvidia/nv-embedqa-e5-v5", truncate="END")
445
+ Settings.llm = NVIDIA(model="meta/llama-3.1-70b-instruct")
446
+ Settings.text_splitter = SentenceSplitter(chunk_size=600)
447
+
448
+ # Create index from documents
449
+ def create_index(documents):
450
+ vector_store = MilvusVectorStore(
451
+ host = "127.0.0.1",
452
+ port = 19530,
453
+ dim = 1024
454
+ )
455
+ # vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True) #For CPU only vector store
456
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
457
+ return VectorStoreIndex.from_documents(documents, storage_context=storage_context)
458
+
459
+ # Main function to run the Streamlit app
460
+ def main():
461
+ set_environment_variables()
462
+ initialize_settings()
463
+
464
+ col1, col2 = st.columns([1, 2])
465
+
466
+ with col1:
467
+ st.title("Multimodal RAG")
468
+
469
+ input_method = st.radio("Choose input method:", ("Upload Files", "Enter Directory Path"))
470
+
471
+ if input_method == "Upload Files":
472
+ uploaded_files = st.file_uploader("Drag and drop files here", accept_multiple_files=True)
473
+ if uploaded_files and st.button("Process Files"):
474
+ with st.spinner("Processing files..."):
475
+ documents = load_multimodal_data(uploaded_files)
476
+ st.session_state['index'] = create_index(documents)
477
+ st.session_state['history'] = []
478
+ st.success("Files processed and index created!")
479
+ else:
480
+ directory_path = st.text_input("Enter directory path:")
481
+ if directory_path and st.button("Process Directory"):
482
+ if os.path.isdir(directory_path):
483
+ with st.spinner("Processing directory..."):
484
+ documents = load_data_from_directory(directory_path)
485
+ st.session_state['index'] = create_index(documents)
486
+ st.session_state['history'] = []
487
+ st.success("Directory processed and index created!")
488
+ else:
489
+ st.error("Invalid directory path. Please enter a valid path.")
490
+
491
+ with col2:
492
+ if 'index' in st.session_state:
493
+ st.title("Chat")
494
+ if 'history' not in st.session_state:
495
+ st.session_state['history'] = []
496
+
497
+ query_engine = st.session_state['index'].as_query_engine(similarity_top_k=5, streaming=True)
498
+
499
+ user_input = st.chat_input("Enter your query:")
500
+
501
+ # Display chat messages
502
+ chat_container = st.container()
503
+ with chat_container:
504
+ for message in st.session_state['history']:
505
+ with st.chat_message(message["role"]):
506
+ st.markdown(message["content"])
507
+
508
+ if user_input:
509
+ with st.chat_message("user"):
510
+ st.markdown(user_input)
511
+ st.session_state['history'].append({"role": "user", "content": user_input})
512
+
513
+ with st.chat_message("assistant"):
514
+ message_placeholder = st.empty()
515
+ full_response = ""
516
+ response = query_engine.query(user_input)
517
+ for token in response.response_gen:
518
+ full_response += token
519
+ message_placeholder.markdown(full_response + "▌")
520
+ message_placeholder.markdown(full_response)
521
+ st.session_state['history'].append({"role": "assistant", "content": full_response})
522
+
523
+ # Add a clear button
524
+ if st.button("Clear Chat"):
525
+ st.session_state['history'] = []
526
+ st.rerun()
527
+
528
+ if __name__ == "__main__":
529
+ main()