AdrienB134 committed
Commit 9d34725 · verified · 1 Parent(s): 5807d3e

Delete rag_demo

Files changed (50):
  1. rag_demo/__init__.py +0 -3
  2. rag_demo/__pycache__/__init__.cpython-311.pyc +0 -0
  3. rag_demo/__pycache__/pipeline.cpython-311.pyc +0 -0
  4. rag_demo/__pycache__/settings.cpython-311.pyc +0 -0
  5. rag_demo/app.py +0 -81
  6. rag_demo/data/test.pdf +0 -0
  7. rag_demo/data/test2.pdf +0 -3
  8. rag_demo/infra/__pycache__/qdrant.cpython-311.pyc +0 -0
  9. rag_demo/infra/qdrant.py +0 -25
  10. rag_demo/pipeline.py +0 -13
  11. rag_demo/preprocessing/__init__.py +0 -5
  12. rag_demo/preprocessing/__pycache__/__init__.cpython-311.pyc +0 -0
  13. rag_demo/preprocessing/__pycache__/chunking.cpython-311.pyc +0 -0
  14. rag_demo/preprocessing/__pycache__/embed.cpython-311.pyc +0 -0
  15. rag_demo/preprocessing/__pycache__/load_to_vectordb.cpython-311.pyc +0 -0
  16. rag_demo/preprocessing/__pycache__/pdf_conversion.cpython-311.pyc +0 -0
  17. rag_demo/preprocessing/base/__init__.py +0 -12
  18. rag_demo/preprocessing/base/__pycache__/__init__.cpython-311.pyc +0 -0
  19. rag_demo/preprocessing/base/__pycache__/chunk.cpython-311.pyc +0 -0
  20. rag_demo/preprocessing/base/__pycache__/document.cpython-311.pyc +0 -0
  21. rag_demo/preprocessing/base/__pycache__/embedded_chunk.cpython-311.pyc +0 -0
  22. rag_demo/preprocessing/base/__pycache__/vectordb.cpython-311.pyc +0 -0
  23. rag_demo/preprocessing/base/chunk.py +0 -13
  24. rag_demo/preprocessing/base/document.py +0 -19
  25. rag_demo/preprocessing/base/embedded_chunk.py +0 -34
  26. rag_demo/preprocessing/base/embeddings.py +0 -45
  27. rag_demo/preprocessing/base/vectordb.py +0 -289
  28. rag_demo/preprocessing/chunking.py +0 -26
  29. rag_demo/preprocessing/embed.py +0 -57
  30. rag_demo/preprocessing/load_to_vectordb.py +0 -30
  31. rag_demo/preprocessing/pdf_conversion.py +0 -33
  32. rag_demo/rag/__pycache__/prompt_templates.cpython-311.pyc +0 -0
  33. rag_demo/rag/__pycache__/query_expansion.cpython-311.pyc +0 -0
  34. rag_demo/rag/__pycache__/reranker.cpython-311.pyc +0 -0
  35. rag_demo/rag/__pycache__/retriever.cpython-311.pyc +0 -0
  36. rag_demo/rag/base/__init__.py +0 -3
  37. rag_demo/rag/base/__pycache__/__init__.cpython-311.pyc +0 -0
  38. rag_demo/rag/base/__pycache__/query.cpython-311.pyc +0 -0
  39. rag_demo/rag/base/__pycache__/template_factory.cpython-311.pyc +0 -0
  40. rag_demo/rag/base/base.py +0 -22
  41. rag_demo/rag/base/query.py +0 -29
  42. rag_demo/rag/base/template_factory.py +0 -22
  43. rag_demo/rag/prompt_templates.py +0 -38
  44. rag_demo/rag/query_expansion.py +0 -39
  45. rag_demo/rag/reranker.py +0 -24
  46. rag_demo/rag/retriever.py +0 -133
  47. rag_demo/settings.py +0 -40
  48. rag_demo/static/Matriv-white.png +0 -0
  49. rag_demo/templates/chat.html +0 -333
  50. rag_demo/templates/upload.html +0 -193
rag_demo/__init__.py DELETED
@@ -1,3 +0,0 @@
-from .infra.qdrant import connection
-
-__all__ = ["connection"]

rag_demo/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (177 Bytes)
 
rag_demo/__pycache__/pipeline.cpython-311.pyc DELETED
Binary file (738 Bytes)
 
rag_demo/__pycache__/settings.cpython-311.pyc DELETED
Binary file (1.99 kB)
 
rag_demo/app.py DELETED
@@ -1,81 +0,0 @@
-from fastapi import FastAPI, File, UploadFile, Request
-from fastapi.templating import Jinja2Templates
-from fastapi.responses import HTMLResponse
-from fastapi.staticfiles import StaticFiles
-from pydantic import BaseModel
-import os
-from pipeline import process_pdf
-import nest_asyncio
-from rag.retriever import RAGPipeline
-from loguru import logger
-
-app = FastAPI()
-
-# Apply nest_asyncio at the start of the application
-nest_asyncio.apply()
-
-# Point Jinja2 at the templates directory
-templates = Jinja2Templates(directory="templates")
-
-app.mount("/static", StaticFiles(directory="static"), name="static")
-
-
-class ChatRequest(BaseModel):
-    question: str
-
-
-@app.get("/", response_class=HTMLResponse)
-async def upload_page(request: Request):
-    return templates.TemplateResponse("upload.html", {"request": request})
-
-
-@app.get("/chat", response_class=HTMLResponse)
-async def chat_page(request: Request):
-    return templates.TemplateResponse("chat.html", {"request": request})
-
-
-@app.post("/upload")
-async def upload_pdf(request: Request, file: UploadFile = File(...)):
-    try:
-        # Create the data directory if it doesn't exist
-        os.makedirs("data", exist_ok=True)
-
-        file_path = f"data/{file.filename}"
-        with open(file_path, "wb") as buffer:
-            content = await file.read()
-            buffer.write(content)
-
-        # Process the PDF file with proper await statements
-        await process_pdf(file_path)
-
-        # Return template response with success message
-        return templates.TemplateResponse(
-            "upload.html",
-            {
-                "request": request,
-                "message": f"Successfully processed {file.filename}",
-                "processing": False,
-            },
-        )
-    except Exception as e:
-        return templates.TemplateResponse(
-            "upload.html", {"request": request, "error": str(e), "processing": False}
-        )
-
-
-@app.post("/chat")
-async def chat(chat_request: ChatRequest):
-    rag_pipeline = RAGPipeline()
-    try:
-        answer = rag_pipeline.rag(chat_request.question)
-        print(answer)
-        logger.info(answer)
-        return {"answer": answer}
-    except Exception as e:
-        return {"error": str(e)}
-
-
-if __name__ == "__main__":
-    import uvicorn
-
-    uvicorn.run(app, host="0.0.0.0", port=7860)

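With the server running on port 7860, both JSON-facing endpoints above can be exercised from a plain HTTP client — a minimal sketch, assuming the `requests` package is installed:

    import requests

    # Upload a PDF; the multipart form field must be named "file", as in upload_pdf.
    with open("data/test.pdf", "rb") as f:
        requests.post("http://localhost:7860/upload", files={"file": f})

    # Ask a question; the handler returns {"answer": ...} or {"error": ...}.
    resp = requests.post(
        "http://localhost:7860/chat", json={"question": "What is this PDF about?"}
    )
    print(resp.json())
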
rag_demo/data/test.pdf DELETED
Binary file (344 kB)
 
rag_demo/data/test2.pdf DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b3041eb7dd274b02a2f18049891dc3f184dff4151796f225b92cd34d676ba923
-size 1962780

rag_demo/infra/__pycache__/qdrant.cpython-311.pyc DELETED
Binary file (1.32 kB)
 
rag_demo/infra/qdrant.py DELETED
@@ -1,25 +0,0 @@
-from loguru import logger
-from qdrant_client import QdrantClient
-from qdrant_client.http.exceptions import UnexpectedResponse
-
-
-class QdrantDatabaseConnector:
-    _instance: QdrantClient | None = None
-
-    def __new__(cls, *args, **kwargs) -> QdrantClient:
-        if cls._instance is None:
-            try:
-                cls._instance = QdrantClient(":memory:")
-
-                logger.info("Connection to in-memory Qdrant DB successful")
-            except Exception:
-                logger.exception(
-                    "Couldn't connect to Qdrant.",
-                )
-
-                raise
-
-        return cls._instance
-
-
-connection = QdrantDatabaseConnector()

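Because `__new__` returns the `QdrantClient` itself, every importer of `connection` shares one in-memory client. A minimal sketch of that property, assuming the `rag_demo` package is importable and `qdrant-client` is installed:

    from rag_demo import connection
    from rag_demo.infra.qdrant import QdrantDatabaseConnector

    # "Instantiating" again just returns the cached client object.
    assert QdrantDatabaseConnector() is connection
    print(connection.get_collections())  # plain qdrant-client API on the shared instance
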
rag_demo/pipeline.py DELETED
@@ -1,13 +0,0 @@
-from preprocessing import (
-    convert_pdf_to_text,
-    load_to_vector_db,
-    chunk_and_embed,
-)
-from loguru import logger
-
-
-async def process_pdf(file_path: str):
-    convert = convert_pdf_to_text([file_path])
-    embedded_chunks = chunk_and_embed([convert])
-    load_to_vector_db(embedded_chunks)
-    return True

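A minimal sketch of driving the pipeline outside FastAPI, assuming the module is importable the same way app.py imports it and `process_pdf` is the coroutine defined above:

    import asyncio

    from pipeline import process_pdf  # same import path app.py uses

    # convert -> chunk + embed -> load into Qdrant, end to end for one file
    asyncio.run(process_pdf("data/test.pdf"))
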
rag_demo/preprocessing/__init__.py DELETED
@@ -1,5 +0,0 @@
-from .pdf_conversion import convert_pdf_to_text
-from .load_to_vectordb import load_to_vector_db
-from .embed import chunk_and_embed
-
-__all__ = ["convert_pdf_to_text", "load_to_vector_db", "chunk_and_embed"]

rag_demo/preprocessing/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (441 Bytes)
 
rag_demo/preprocessing/__pycache__/chunking.cpython-311.pyc DELETED
Binary file (1.25 kB)
 
rag_demo/preprocessing/__pycache__/embed.cpython-311.pyc DELETED
Binary file (3.53 kB)
 
rag_demo/preprocessing/__pycache__/load_to_vectordb.cpython-311.pyc DELETED
Binary file (2.39 kB)
 
rag_demo/preprocessing/__pycache__/pdf_conversion.cpython-311.pyc DELETED
Binary file (1.8 kB)
 
rag_demo/preprocessing/base/__init__.py DELETED
@@ -1,12 +0,0 @@
-from .document import Document, CleanedDocument
-from .chunk import Chunk
-from .embedded_chunk import EmbeddedChunk
-from .vectordb import VectorBaseDocument
-
-__all__ = [
-    "Document",
-    "CleanedDocument",
-    "Chunk",
-    "EmbeddedChunk",
-    "VectorBaseDocument",
-]

rag_demo/preprocessing/base/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (528 Bytes)
 
rag_demo/preprocessing/base/__pycache__/chunk.cpython-311.pyc DELETED
Binary file (927 Bytes)
 
rag_demo/preprocessing/base/__pycache__/document.cpython-311.pyc DELETED
Binary file (1.12 kB)
 
rag_demo/preprocessing/base/__pycache__/embedded_chunk.cpython-311.pyc DELETED
Binary file (2.04 kB)
 
rag_demo/preprocessing/base/__pycache__/vectordb.cpython-311.pyc DELETED
Binary file (16.7 kB)
 
rag_demo/preprocessing/base/chunk.py DELETED
@@ -1,13 +0,0 @@
-from abc import ABC
-from typing import Optional
-
-from pydantic import UUID4, Field
-
-from .vectordb import VectorBaseDocument
-
-
-class Chunk(VectorBaseDocument, ABC):
-    content: str
-    document_id: UUID4
-    chunk_id: UUID4
-    metadata: dict = Field(default_factory=dict)

rag_demo/preprocessing/base/document.py DELETED
@@ -1,19 +0,0 @@
-from abc import ABC
-from typing import Optional
-
-from pydantic import UUID4, BaseModel
-
-from .vectordb import VectorBaseDocument
-
-
-class CleanedDocument(VectorBaseDocument, ABC):
-    content: str
-    doc_id: UUID4
-    doc_title: str
-    # doc_url: str
-
-
-class Document(BaseModel):
-    text: str
-    document_id: UUID4
-    metadata: dict

rag_demo/preprocessing/base/embedded_chunk.py DELETED
@@ -1,34 +0,0 @@
-from abc import ABC
-
-from pydantic import UUID4, Field
-
-
-from .vectordb import VectorBaseDocument
-
-
-class EmbeddedChunk(VectorBaseDocument, ABC):
-    content: str
-    embedding: list[float] | None
-    document_id: UUID4
-    chunk_id: UUID4
-    metadata: dict = Field(default_factory=dict)
-    similarity: float | None
-
-    @classmethod
-    def to_context(cls, chunks: list["EmbeddedChunk"]) -> str:
-        context = ""
-        for i, chunk in enumerate(chunks):
-            context += f"""
-            Chunk {i + 1}:
-            Type: {chunk.__class__.__name__}
-            Document ID: {chunk.document_id}
-            Chunk ID: {chunk.chunk_id}
-            Content: {chunk.content}\n
-            """
-
-        return context
-
-    class Config:
-        name = "embedded_documents"
-        category = "Document"
-        use_vector_index = True

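`to_context` is what stitches retrieved chunks into the prompt context. A quick sketch of its output — `EmbeddedChunk` declares no abstract methods, so it can be instantiated directly for illustration; the UUIDs here are made up:

    from uuid import uuid4

    from rag_demo.preprocessing.base import EmbeddedChunk

    chunks = [
        EmbeddedChunk(content=text, embedding=None, document_id=uuid4(),
                      chunk_id=uuid4(), similarity=None)
        for text in ("First passage.", "Second passage.")
    ]
    # Prints one "Chunk i / Type / Document ID / Chunk ID / Content" block per chunk.
    print(EmbeddedChunk.to_context(chunks))
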
rag_demo/preprocessing/base/embeddings.py DELETED
@@ -1,45 +0,0 @@
-from functools import cached_property
-from pathlib import Path
-from typing import Optional, ClassVar
-from threading import Lock
-
-
-
-
-class SingletonMeta(type):
-    """
-    This is a thread-safe implementation of Singleton.
-    """
-
-    _instances: ClassVar = {}
-
-    _lock: Lock = Lock()
-
-    """
-    We now have a lock object that will be used to synchronize threads during
-    first access to the Singleton.
-    """
-
-    def __call__(cls, *args, **kwargs):
-        """
-        Possible changes to the value of the `__init__` argument do not affect
-        the returned instance.
-        """
-        # Now, imagine that the program has just been launched. Since there's no
-        # Singleton instance yet, multiple threads can simultaneously pass the
-        # previous conditional and reach this point almost at the same time. The
-        # first of them will acquire lock and will proceed further, while the
-        # rest will wait here.
-        with cls._lock:
-            # The first thread to acquire the lock, reaches this conditional,
-            # goes inside and creates the Singleton instance. Once it leaves the
-            # lock block, a thread that might have been waiting for the lock
-            # release may then enter this section. But since the Singleton field
-            # is already initialized, the thread won't create a new object.
-            if cls not in cls._instances:
-                instance = super().__call__(*args, **kwargs)
-                cls._instances[cls] = instance
-
-        return cls._instances[cls]

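What the metaclass guarantees is easiest to see with a throwaway consumer class (hypothetical, not part of the repo):

    from threading import Lock
    from typing import ClassVar

    class SingletonMeta(type):
        _instances: ClassVar = {}
        _lock: Lock = Lock()

        def __call__(cls, *args, **kwargs):
            with cls._lock:
                if cls not in cls._instances:
                    cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]

    class EmbeddingModel(metaclass=SingletonMeta):  # hypothetical consumer
        def __init__(self, name: str) -> None:
            self.name = name

    a = EmbeddingModel("e5-large")
    b = EmbeddingModel("other")  # later __init__ arguments are ignored
    assert a is b and b.name == "e5-large"
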
rag_demo/preprocessing/base/vectordb.py DELETED
@@ -1,289 +0,0 @@
-import uuid
-from abc import ABC
-from typing import Any, Callable, Dict, Generic, Type, TypeVar
-from uuid import UUID
-
-import numpy as np
-from loguru import logger
-from pydantic import UUID4, BaseModel, Field
-from qdrant_client.http import exceptions
-from qdrant_client.http.models import Distance, VectorParams
-from qdrant_client.models import CollectionInfo, PointStruct, Record
-
-
-from rag_demo import connection
-
-T = TypeVar("T", bound="VectorBaseDocument")
-
-EMBEDDING_SIZE = 1024
-
-
-class VectorBaseDocument(BaseModel, Generic[T], ABC):
-    id: UUID4 = Field(default_factory=uuid.uuid4)
-
-    def __eq__(self, value: object) -> bool:
-        if not isinstance(value, self.__class__):
-            return False
-
-        return self.id == value.id
-
-    def __hash__(self) -> int:
-        return hash(self.id)
-
-    @classmethod
-    def from_record(cls: Type[T], point: Record) -> T:
-        _id = UUID(point.id, version=4)
-        payload = point.payload or {}
-
-        attributes = {
-            "id": _id,
-            **payload,
-        }
-        if cls._has_class_attribute("embedding"):
-            attributes["embedding"] = point.vector or None
-
-        return cls(**attributes)
-
-    def to_point(self: T, **kwargs) -> PointStruct:
-        exclude_unset = kwargs.pop("exclude_unset", False)
-        by_alias = kwargs.pop("by_alias", True)
-
-        payload = self.model_dump(
-            exclude_unset=exclude_unset, by_alias=by_alias, **kwargs
-        )
-
-        _id = str(payload.pop("id"))
-        vector = payload.pop("embedding", {})
-        if vector and isinstance(vector, np.ndarray):
-            vector = vector.tolist()
-
-        return PointStruct(id=_id, vector=vector, payload=payload)
-
-    def model_dump(self: T, **kwargs) -> dict:
-        dict_ = super().model_dump(**kwargs)
-
-        dict_ = self._uuid_to_str(dict_)
-
-        return dict_
-
-    def _uuid_to_str(self, item: Any) -> Any:
-        if isinstance(item, dict):
-            for key, value in item.items():
-                if isinstance(value, UUID):
-                    item[key] = str(value)
-                elif isinstance(value, list):
-                    item[key] = [self._uuid_to_str(v) for v in value]
-                elif isinstance(value, dict):
-                    item[key] = {k: self._uuid_to_str(v) for k, v in value.items()}
-
-        return item
-
-    @classmethod
-    def bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> bool:
-        try:
-            cls._bulk_insert(documents)
-            logger.info(
-                f"Successfully inserted {len(documents)} documents into {cls.get_collection_name()}"
-            )
-
-        except Exception as e:
-            logger.error(f"Error inserting documents: {e}")
-            logger.info(
-                f"Collection '{cls.get_collection_name()}' does not exist. Trying to create the collection and reinsert the documents."
-            )
-
-            cls.create_collection()
-
-            try:
-                cls._bulk_insert(documents)
-            except Exception as e:
-                logger.error(f"Error inserting documents: {e}")
-                logger.error(
-                    f"Failed to insert documents in '{cls.get_collection_name()}'."
-                )
-
-                return False
-
-        return True
-
-    @classmethod
-    def _bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> None:
-        points = [doc.to_point() for doc in documents]
-
-        connection.upsert(collection_name=cls.get_collection_name(), points=points)
-
-    @classmethod
-    def bulk_find(
-        cls: Type[T], limit: int = 10, **kwargs
-    ) -> tuple[list[T], UUID | None]:
-        try:
-            documents, next_offset = cls._bulk_find(limit=limit, **kwargs)
-        except exceptions.UnexpectedResponse:
-            logger.error(
-                f"Failed to search documents in '{cls.get_collection_name()}'."
-            )
-
-            documents, next_offset = [], None
-
-        return documents, next_offset
-
-    @classmethod
-    def _bulk_find(
-        cls: Type[T], limit: int = 10, **kwargs
-    ) -> tuple[list[T], UUID | None]:
-        collection_name = cls.get_collection_name()
-
-        offset = kwargs.pop("offset", None)
-        offset = str(offset) if offset else None
-
-        records, next_offset = connection.scroll(
-            collection_name=collection_name,
-            limit=limit,
-            with_payload=kwargs.pop("with_payload", True),
-            with_vectors=kwargs.pop("with_vectors", False),
-            offset=offset,
-            **kwargs,
-        )
-        documents = [cls.from_record(record) for record in records]
-        if next_offset is not None:
-            next_offset = UUID(next_offset, version=4)
-
-        return documents, next_offset
-
-    @classmethod
-    def search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
-        try:
-            documents = cls._search(query_vector=query_vector, limit=limit, **kwargs)
-        except exceptions.UnexpectedResponse:
-            logger.error(
-                f"Failed to search documents in '{cls.get_collection_name()}'."
-            )
-
-            documents = []
-
-        return documents
-
-    @classmethod
-    def _search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
-        collection_name = cls.get_collection_name()
-        records = connection.search(
-            collection_name=collection_name,
-            query_vector=query_vector,
-            limit=limit,
-            with_payload=kwargs.pop("with_payload", True),
-            with_vectors=kwargs.pop("with_vectors", False),
-            **kwargs,
-        )
-        documents = [cls.from_record(record) for record in records]
-
-        return documents
-
-    @classmethod
-    def get_or_create_collection(cls: Type[T]) -> CollectionInfo:
-        collection_name = cls.get_collection_name()
-
-        try:
-            return connection.get_collection(collection_name=collection_name)
-        except exceptions.UnexpectedResponse:
-            use_vector_index = cls.get_use_vector_index()
-
-            collection_created = cls._create_collection(
-                collection_name=collection_name, use_vector_index=use_vector_index
-            )
-            if collection_created is False:
-                raise RuntimeError(
-                    f"Couldn't create collection {collection_name}"
-                ) from None
-
-            return connection.get_collection(collection_name=collection_name)
-
-    @classmethod
-    def create_collection(cls: Type[T]) -> bool:
-        collection_name = cls.get_collection_name()
-        use_vector_index = cls.get_use_vector_index()
-        logger.info(
-            f"Creating collection {collection_name} with use_vector_index={use_vector_index}"
-        )
-        return cls._create_collection(
-            collection_name=collection_name, use_vector_index=use_vector_index
-        )
-
-    @classmethod
-    def _create_collection(
-        cls, collection_name: str, use_vector_index: bool = True
-    ) -> bool:
-        if use_vector_index is True:
-            vectors_config = VectorParams(size=EMBEDDING_SIZE, distance=Distance.COSINE)
-        else:
-            vectors_config = {}
-
-        return connection.create_collection(
-            collection_name=collection_name, vectors_config=vectors_config
-        )
-
-    @classmethod
-    def get_collection_name(cls: Type[T]) -> str:
-        if not hasattr(cls, "Config") or not hasattr(cls.Config, "name"):
-            raise Exception(
-                f"The class {cls} should define a Config class with the 'name' property that reflects the collection's name."
-            )
-
-        return cls.Config.name
-
-    @classmethod
-    def get_use_vector_index(cls: Type[T]) -> bool:
-        if not hasattr(cls, "Config") or not hasattr(cls.Config, "use_vector_index"):
-            return True
-
-        return cls.Config.use_vector_index
-
-    @classmethod
-    def group_by_class(
-        cls: Type["VectorBaseDocument"], documents: list["VectorBaseDocument"]
-    ) -> Dict["VectorBaseDocument", list["VectorBaseDocument"]]:
-        return cls._group_by(documents, selector=lambda doc: doc.__class__)
-
-    @classmethod
-    def _group_by(
-        cls: Type[T], documents: list[T], selector: Callable[[T], Any]
-    ) -> Dict[Any, list[T]]:
-        grouped = {}
-        for doc in documents:
-            key = selector(doc)
-
-            if key not in grouped:
-                grouped[key] = []
-            grouped[key].append(doc)
-
-        return grouped
-
-    @classmethod
-    def collection_name_to_class(
-        cls: Type["VectorBaseDocument"], collection_name: str
-    ) -> type["VectorBaseDocument"]:
-        for subclass in cls.__subclasses__():
-            try:
-                if subclass.get_collection_name() == collection_name:
-                    return subclass
-            except Exception:
-                pass
-
-            try:
-                return subclass.collection_name_to_class(collection_name)
-            except ValueError:
-                continue
-
-        raise ValueError(f"No subclass found for collection name: {collection_name}")
-
-    @classmethod
-    def _has_class_attribute(cls: Type[T], attribute_name: str) -> bool:
-        if attribute_name in cls.__annotations__:
-            return True
-
-        for base in cls.__bases__:
-            if hasattr(base, "_has_class_attribute") and base._has_class_attribute(
-                attribute_name
-            ):
-                return True
-
-        return False

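`_group_by` is a small generic helper; its bucketing can be sanity-checked with plain values standing in for documents:

    from typing import Any, Callable

    def group_by(items: list, selector: Callable[[Any], Any]) -> dict:
        # Mirrors VectorBaseDocument._group_by: bucket items by selector(item).
        grouped: dict = {}
        for item in items:
            grouped.setdefault(selector(item), []).append(item)
        return grouped

    assert group_by([1, 2, 3, 4], selector=lambda n: n % 2) == {1: [1, 3], 0: [2, 4]}
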
rag_demo/preprocessing/chunking.py DELETED
@@ -1,26 +0,0 @@
-from uuid import uuid4
-
-from langchain.text_splitter import MarkdownTextSplitter
-from .base import Chunk
-from .base import Document
-
-
-def chunk_text(
-    document: Document, chunk_size: int = 500, chunk_overlap: int = 50
-) -> list[Chunk]:
-    text_splitter = MarkdownTextSplitter(
-        chunk_size=chunk_size, chunk_overlap=chunk_overlap
-    )
-    chunks = text_splitter.split_text(document.text)
-    result = []
-    for chunk in chunks:
-        result.append(
-            Chunk(
-                content=chunk,
-                document_id=document.document_id,
-                chunk_id=uuid4(),
-                metadata=document.metadata,
-            )
-        )
-
-    return result

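The splitter's size/overlap behaviour is easy to check in isolation — a sketch, assuming `langchain` is installed:

    from langchain.text_splitter import MarkdownTextSplitter

    splitter = MarkdownTextSplitter(chunk_size=100, chunk_overlap=20)
    markdown = "# Title\n\n" + "Some body text. " * 20
    pieces = splitter.split_text(markdown)
    # Each piece is at most ~100 characters, with ~20 characters of overlap
    # between consecutive pieces so context isn't cut mid-thought.
    print(len(pieces), pieces[0])
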
rag_demo/preprocessing/embed.py DELETED
@@ -1,57 +0,0 @@
-from typing_extensions import Annotated
-from typing import Generator
-from .base import Chunk
-from .base import EmbeddedChunk
-from .chunking import chunk_text
-from huggingface_hub import InferenceClient
-import os
-from dotenv import load_dotenv
-from uuid import uuid4
-from loguru import logger
-
-load_dotenv()
-
-
-def batch(list_: list, size: int) -> Generator[list, None, None]:
-    yield from (list_[i : i + size] for i in range(0, len(list_), size))
-
-
-def embed_chunks(chunks: list[Chunk]) -> list[EmbeddedChunk]:
-    api = InferenceClient(
-        model="intfloat/multilingual-e5-large-instruct",
-        token=os.getenv("HF_API_TOKEN"),
-    )
-    logger.info(f"Embedding {len(chunks)} chunks")
-    embedded_chunks = []
-    for chunk in chunks:
-        try:
-            embedded_chunks.append(
-                EmbeddedChunk(
-                    id=uuid4(),
-                    content=chunk.content,
-                    embedding=api.feature_extraction(chunk.content),
-                    document_id=chunk.document_id,
-                    chunk_id=chunk.chunk_id,
-                    metadata=chunk.metadata,
-                    similarity=None,
-                )
-            )
-        except Exception as e:
-            logger.error(f"Error embedding chunk: {e}")
-    logger.info(f"{len(embedded_chunks)} chunks embedded successfully")
-
-    return embedded_chunks
-
-
-def chunk_and_embed(
-    cleaned_documents: Annotated[list, "cleaned_documents"],
-) -> Annotated[list, "embedded_documents"]:
-    embedded_chunks = []
-    for document in cleaned_documents:
-        chunks = chunk_text(document)
-
-        for batched_chunks in batch(chunks, 10):
-            batched_embedded_chunks = embed_chunks(batched_chunks)
-            embedded_chunks.extend(batched_embedded_chunks)
-    logger.info(f"{len(embedded_chunks)} chunks embedded successfully")
-    return embedded_chunks

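`batch` (duplicated verbatim in load_to_vectordb.py) just windows a list; the final window may be short, but nothing is dropped:

    from typing import Generator

    def batch(list_: list, size: int) -> Generator[list, None, None]:
        yield from (list_[i : i + size] for i in range(0, len(list_), size))

    assert list(batch([1, 2, 3, 4, 5, 6, 7], 3)) == [[1, 2, 3], [4, 5, 6], [7]]
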
rag_demo/preprocessing/load_to_vectordb.py DELETED
@@ -1,30 +0,0 @@
-from loguru import logger
-from typing_extensions import Annotated
-from typing import Generator
-
-from .base import VectorBaseDocument
-
-
-def batch(list_: list, size: int) -> Generator[list, None, None]:
-    yield from (list_[i : i + size] for i in range(0, len(list_), size))
-
-
-def load_to_vector_db(
-    documents: Annotated[list, "documents"],
-) -> Annotated[bool, "successful"]:
-    logger.info(f"Loading {len(documents)} documents into the vector database.")
-
-    grouped_documents = VectorBaseDocument.group_by_class(documents)
-    for document_class, documents in grouped_documents.items():
-        logger.info(f"Loading documents into {document_class.get_collection_name()}")
-        for documents_batch in batch(documents, size=4):
-            try:
-                document_class.bulk_insert(documents_batch)
-            except Exception as e:
-                logger.error(
-                    f"Failed to insert documents into {document_class.get_collection_name()}: {e}"
-                )
-
-                return False
-
-    return True

rag_demo/preprocessing/pdf_conversion.py DELETED
@@ -1,33 +0,0 @@
-from llama_parse import LlamaParse
-from llama_index.core import SimpleDirectoryReader
-from uuid import uuid4
-from .base import Document
-from loguru import logger
-import os
-from dotenv import load_dotenv
-
-load_dotenv()
-
-
-# set up parser
-parser = LlamaParse(
-    api_key=os.getenv("LLAMA_PARSE_API_KEY"),
-    result_type="markdown",  # "markdown" and "text" are available
-)
-
-
-def convert_pdf_to_text(filepaths: list[str]) -> Document:
-    file_extractor = {".pdf": parser}
-    # use SimpleDirectoryReader to parse our file
-
-    documents = SimpleDirectoryReader(
-        input_files=filepaths, file_extractor=file_extractor
-    ).load_data()
-
-    logger.info(f"Converted {len(documents)} documents")
-
-    return Document(
-        document_id=uuid4(),
-        text=" ".join(document.text for document in documents),
-        metadata={"filename": filepaths[0].split("/")[-1]},
-    )

rag_demo/rag/__pycache__/prompt_templates.cpython-311.pyc DELETED
Binary file (2.75 kB)
 
rag_demo/rag/__pycache__/query_expansion.cpython-311.pyc DELETED
Binary file (2.4 kB)
 
rag_demo/rag/__pycache__/reranker.cpython-311.pyc DELETED
Binary file (1.96 kB)
 
rag_demo/rag/__pycache__/retriever.cpython-311.pyc DELETED
Binary file (8.21 kB)
 
rag_demo/rag/base/__init__.py DELETED
@@ -1,3 +0,0 @@
-from .template_factory import PromptTemplateFactory
-
-__all__ = ["PromptTemplateFactory"]

rag_demo/rag/base/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (283 Bytes)
 
rag_demo/rag/base/__pycache__/query.cpython-311.pyc DELETED
Binary file (2.08 kB)
 
rag_demo/rag/base/__pycache__/template_factory.cpython-311.pyc DELETED
Binary file (1.64 kB)
 
rag_demo/rag/base/base.py DELETED
@@ -1,22 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any
-
-from langchain.prompts import PromptTemplate
-from pydantic import BaseModel
-
-from rag_demo.rag.base.query import Query
-
-
-class PromptTemplateFactory(ABC, BaseModel):
-    @abstractmethod
-    def create_template(self) -> PromptTemplate:
-        pass
-
-
-class RAGStep(ABC):
-    def __init__(self, mock: bool = False) -> None:
-        self._mock = mock
-
-    @abstractmethod
-    def generate(self, query: Query, *args, **kwargs) -> Any:
-        pass

rag_demo/rag/base/query.py DELETED
@@ -1,29 +0,0 @@
-from pydantic import UUID4, Field
-
-from rag_demo.preprocessing.base import VectorBaseDocument
-
-
-class Query(VectorBaseDocument):
-    content: str
-    metadata: dict = Field(default_factory=dict)
-
-    class Config:
-        category = "query"
-
-    @classmethod
-    def from_str(cls, query: str) -> "Query":
-        return Query(content=query.strip("\n "))
-
-    def replace_content(self, new_content: str) -> "Query":
-        return Query(
-            id=self.id,
-            content=new_content,
-            metadata=self.metadata,
-        )
-
-
-class EmbeddedQuery(Query):
-    embedding: list[float]
-
-    class Config:
-        category = "query"

rag_demo/rag/base/template_factory.py DELETED
@@ -1,22 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any
-
-from langchain.prompts import PromptTemplate
-from pydantic import BaseModel
-
-from .query import Query
-
-
-class PromptTemplateFactory(ABC, BaseModel):
-    @abstractmethod
-    def create_template(self) -> PromptTemplate:
-        pass
-
-
-class RAGStep(ABC):
-    def __init__(self, mock: bool = False) -> None:
-        self._mock = mock
-
-    @abstractmethod
-    def generate(self, query: Query, *args, **kwargs) -> Any:
-        pass

rag_demo/rag/prompt_templates.py DELETED
@@ -1,38 +0,0 @@
-from langchain.prompts import PromptTemplate
-
-from .base import PromptTemplateFactory
-
-
-class QueryExpansionTemplate(PromptTemplateFactory):
-    prompt: str = """You are an AI language model assistant. Your task is to generate {expand_to_n}
-    different versions of the given user question to retrieve relevant documents from a vector
-    database. By generating multiple perspectives on the user question, your goal is to help
-    the user overcome some of the limitations of the distance-based similarity search.
-    Provide these alternative questions separated by '{separator}'.
-    Original question: {question}"""
-
-    @property
-    def separator(self) -> str:
-        return "#next-question#"
-
-    def create_template(self, expand_to_n: int) -> PromptTemplate:
-        return PromptTemplate(
-            template=self.prompt,
-            input_variables=["question"],
-            partial_variables={
-                "separator": self.separator,
-                "expand_to_n": expand_to_n,
-            },
-        )
-
-
-class AnswerGenerationTemplate(PromptTemplateFactory):
-    prompt: str = """You are an AI language model assistant. Your task is to generate an answer to the given user question based on the provided context.
-    Context: {context}
-    Question: {question}
-
-    Give your answer in markdown format.
-    Give only your answer, do not include any other text like 'Certainly! Here is the answer:' or 'The answer is:' or anything similar."""
-
-    def create_template(self, context: str, question: str) -> str:
-        return self.prompt.format(context=context, question=question)

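How the partial variables combine with the one remaining input variable — a small sketch, assuming `langchain` is available:

    from langchain.prompts import PromptTemplate

    template = PromptTemplate(
        template="Expand {question} into {expand_to_n} variants, separated by {separator}.",
        input_variables=["question"],
        partial_variables={"separator": "#next-question#", "expand_to_n": 3},
    )
    # Only `question` is still free; the partials were bound at construction time.
    print(template.format(question="What is RAG?"))
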
rag_demo/rag/query_expansion.py DELETED
@@ -1,39 +0,0 @@
-import os
-from typing import Any
-
-from huggingface_hub import InferenceClient
-
-from .base.query import Query
-from .base.template_factory import RAGStep
-from .prompt_templates import QueryExpansionTemplate
-
-
-class QueryExpansion(RAGStep):
-    def generate(self, query: Query, expand_to_n: int) -> Any:
-        api = InferenceClient(
-            model="Qwen/Qwen2.5-72B-Instruct",
-            token=os.getenv("HF_API_TOKEN"),
-        )
-        query_expansion_template = QueryExpansionTemplate()
-        prompt = query_expansion_template.create_template(expand_to_n - 1)
-        response = api.chat_completion(
-            [
-                {
-                    "role": "user",
-                    "content": prompt.template.format(
-                        question=query.content,
-                        expand_to_n=expand_to_n,
-                        separator=query_expansion_template.separator,
-                    ),
-                }
-            ]
-        )
-        result = response.choices[0].message.content
-        queries_content = result.split(query_expansion_template.separator)
-        queries = [query]
-        queries += [
-            query.replace_content(stripped_content)
-            for content in queries_content
-            if (stripped_content := content.strip())
-        ]
-        return queries

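The post-processing step (split on the separator, drop empties via the walrus filter, keep the original query first) needs no model call to verify:

    separator = "#next-question#"
    raw = " How do I chunk PDFs? #next-question# What splits PDFs into chunks? #next-question# "
    expanded = ["original query"]
    expanded += [s for part in raw.split(separator) if (s := part.strip())]
    # -> ['original query', 'How do I chunk PDFs?', 'What splits PDFs into chunks?']
    print(expanded)
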
rag_demo/rag/reranker.py DELETED
@@ -1,24 +0,0 @@
-import os
-
-from huggingface_hub import InferenceClient
-
-from .base.query import Query
-from .base.template_factory import RAGStep
-from ..preprocessing.base import EmbeddedChunk
-
-
-class Reranker(RAGStep):
-    def generate(
-        self, query: Query, chunks: list[EmbeddedChunk], keep_top_k: int
-    ) -> list[EmbeddedChunk]:
-        api = InferenceClient(
-            model="intfloat/multilingual-e5-large-instruct",
-            token=os.getenv("HF_API_TOKEN"),
-        )
-        similarity = api.sentence_similarity(
-            query.content, [chunk.content for chunk in chunks]
-        )
-        for chunk, sim in zip(chunks, similarity):
-            chunk.similarity = sim
-
-        return sorted(chunks, key=lambda x: x.similarity, reverse=True)[:keep_top_k]

rag_demo/rag/retriever.py DELETED
@@ -1,133 +0,0 @@
-import concurrent.futures
-import os
-
-from loguru import logger
-from qdrant_client.models import FieldCondition, Filter, MatchValue
-from huggingface_hub import InferenceClient
-
-from ..preprocessing.base import (
-    EmbeddedChunk,
-)
-from .base.query import EmbeddedQuery, Query
-
-from .query_expansion import QueryExpansion
-from .reranker import Reranker
-from .prompt_templates import AnswerGenerationTemplate
-
-from dotenv import load_dotenv
-
-load_dotenv()
-
-
-def flatten(nested_list: list) -> list:
-    """Flatten a list of lists into a single list."""
-
-    return [item for sublist in nested_list for item in sublist]
-
-
-class RAGPipeline:
-    def __init__(self, mock: bool = False) -> None:
-        self._query_expander = QueryExpansion(mock=mock)
-        self._reranker = Reranker(mock=mock)
-
-    def search(
-        self,
-        query: str,
-        k: int = 3,
-        expand_to_n_queries: int = 3,
-    ) -> list:
-        query_model = Query.from_str(query)
-
-        n_generated_queries = self._query_expander.generate(
-            query_model, expand_to_n=expand_to_n_queries
-        )
-        logger.info(
-            f"Successfully generated {len(n_generated_queries)} search queries.",
-        )
-
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            search_tasks = [
-                executor.submit(self._search, _query_model, k)
-                for _query_model in n_generated_queries
-            ]
-
-            n_k_documents = [
-                task.result() for task in concurrent.futures.as_completed(search_tasks)
-            ]
-            n_k_documents = flatten(n_k_documents)
-            n_k_documents = list(set(n_k_documents))
-
-        logger.info(f"{len(n_k_documents)} documents retrieved successfully")
-
-        if len(n_k_documents) > 0:
-            k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
-        else:
-            k_documents = []
-
-        return k_documents
-
-    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
-        assert k >= 3, "k should be >= 3"
-
-        def _search_data(
-            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
-        ) -> list[EmbeddedChunk]:
-            return data_category_odm.search(
-                query_vector=embedded_query.embedding,
-                limit=k,
-            )
-
-        api = InferenceClient(
-            model="intfloat/multilingual-e5-large-instruct",
-            token=os.getenv("HF_API_TOKEN"),
-        )
-        embedded_query: EmbeddedQuery = EmbeddedQuery(
-            embedding=api.feature_extraction(query.content),
-            id=query.id,
-            content=query.content,
-        )
-
-        retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
-        logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")
-
-        return retrieved_chunks
-
-    def rerank(
-        self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
-    ) -> list[EmbeddedChunk]:
-        if isinstance(query, str):
-            query = Query.from_str(query)
-
-        reranked_documents = self._reranker.generate(
-            query=query, chunks=chunks, keep_top_k=keep_top_k
-        )
-
-        logger.info(f"{len(reranked_documents)} documents reranked successfully.")
-
-        return reranked_documents
-
-    def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
-        context = ""
-        for chunk in reranked_chunks:
-            context += "\n Document: "
-            context += chunk.content
-        api = InferenceClient(
-            model="meta-llama/Llama-3.1-8B-Instruct",
-            token=os.getenv("HF_API_TOKEN"),
-        )
-        answer_generation_template = AnswerGenerationTemplate()
-        prompt = answer_generation_template.create_template(context, query)
-        logger.info(prompt)
-        response = api.chat_completion(
-            [{"role": "user", "content": prompt}],
-            max_tokens=8192,
-        )
-        return response.choices[0].message.content
-
-    def rag(self, query: str) -> tuple[str, list[str]]:
-        docs = self.search(query, k=10)
-        reranked_docs = self.rerank(query, docs, keep_top_k=10)
-        return (
-            self.generate_answer(query, reranked_docs),
-            [doc.metadata["filename"].split(".pdf")[0] for doc in reranked_docs],
-        )

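The fan-out in `search` — one vector search per expanded query, flattened, then de-duplicated through the UUID-based `__eq__`/`__hash__` on `VectorBaseDocument` — can be mimicked with a stub search function:

    import concurrent.futures

    def fake_search(q: str) -> list[str]:
        # Stand-in for one per-query vector search; the real code returns EmbeddedChunks.
        return {"a": ["doc1", "doc2"], "b": ["doc2", "doc3"]}[q]

    with concurrent.futures.ThreadPoolExecutor() as pool:
        tasks = [pool.submit(fake_search, q) for q in ("a", "b")]
        hits = [t.result() for t in concurrent.futures.as_completed(tasks)]

    flat = [d for sub in hits for d in sub]
    unique = list(set(flat))  # duplicates across expanded queries collapse here
    print(sorted(unique))     # ['doc1', 'doc2', 'doc3']
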
rag_demo/settings.py DELETED
@@ -1,40 +0,0 @@
-from loguru import logger
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-
-class Settings(BaseSettings):
-    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
-
-    # Huggingface API
-    HF_API_KEY: str | None = None
-
-    # LlamaParse API
-    LLAMA_PARSE_API_KEY: str | None = None
-
-    # Qdrant vector database
-    USE_QDRANT_CLOUD: bool = False
-    QDRANT_DATABASE_HOST: str = "localhost"
-    QDRANT_DATABASE_PORT: int = 6333
-    QDRANT_CLOUD_URL: str = "str"
-    QDRANT_APIKEY: str | None = None
-
-    # RAG
-    TEXT_EMBEDDING_MODEL_ID: str = "sentence-transformers/all-MiniLM-L6-v2"
-    RERANKING_CROSS_ENCODER_MODEL_ID: str = "cross-encoder/ms-marco-MiniLM-L-4-v2"
-    RAG_MODEL_DEVICE: str = "cpu"
-
-    @classmethod
-    def load_settings(cls) -> "Settings":
-        """
-        Initializes the settings from the .env file and default values.
-
-        Returns:
-            Settings: The initialized settings object.
-        """
-
-        settings = Settings()
-
-        return settings
-
-
-settings = Settings.load_settings()

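pydantic-settings resolves each field from the environment and the `.env` file before falling back to the declared default — a minimal sketch:

    import os

    from pydantic_settings import BaseSettings, SettingsConfigDict

    class DemoSettings(BaseSettings):
        model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
        QDRANT_DATABASE_PORT: int = 6333

    os.environ["QDRANT_DATABASE_PORT"] = "6334"
    assert DemoSettings().QDRANT_DATABASE_PORT == 6334  # env var wins over the default
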
rag_demo/static/Matriv-white.png DELETED
Binary file (6.5 kB)
 
rag_demo/templates/chat.html DELETED
@@ -1,333 +0,0 @@
-<!DOCTYPE html>
-<html>
-
-<head>
-    <title>RAG Chatbot</title>
-    <style>
-        :root {
-            --primary-color: #a0a0a0;
-            --background-color: #1a1a1a;
-            --card-background: #2d2d2d;
-            --text-color: #e0e0e0;
-            --border-radius: 6px;
-            --shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
-            --input-background: #363636;
-            --input-border: #404040;
-        }
-
-        body {
-            font-family: 'Segoe UI', Arial, sans-serif;
-            max-width: 1200px;
-            margin: 0 auto;
-            padding: 20px;
-            background-color: var(--background-color);
-            color: var(--text-color);
-        }
-
-        .card {
-            background: var(--card-background);
-            border-radius: var(--border-radius);
-            box-shadow: var(--shadow);
-            padding: 2rem;
-            margin: 2rem 0;
-        }
-
-        .chat-container {
-            background: var(--card-background);
-            border-radius: var(--border-radius);
-            padding: 1.5rem;
-            height: 700px;
-            overflow-y: auto;
-            margin-bottom: 1.5rem;
-            box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05);
-            border: 1px solid var(--input-border);
-        }
-
-        .message {
-            margin-bottom: 1rem;
-            padding: 1rem;
-            border-radius: 4px;
-            max-width: 70%;
-            animation: fadeIn 0.3s ease;
-        }
-
-        @keyframes fadeIn {
-            from {
-                opacity: 0;
-                transform: translateY(10px);
-            }
-
-            to {
-                opacity: 1;
-                transform: translateY(0);
-            }
-        }
-
-        .user-message {
-            background-color: #808080;
-            margin-left: auto;
-            color: #ffffff;
-            box-shadow: 0 2px 4px rgba(128, 128, 128, 0.2);
-        }
-
-        .bot-message {
-            background-color: #363636;
-            margin-right: auto;
-            color: #e0e0e0;
-            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
-        }
-
-        .input-container {
-            display: flex;
-            gap: 12px;
-            padding: 1rem;
-            background: var(--card-background);
-            border-radius: var(--border-radius);
-            box-shadow: var(--shadow);
-        }
-
-        .nav {
-            background: var(--card-background);
-            padding: 1rem;
-            border-radius: var(--border-radius);
-            box-shadow: var(--shadow);
-            margin-bottom: 1rem;
-        }
-
-        .nav a {
-            margin-right: 20px;
-            text-decoration: none;
-            color: var(--primary-color);
-            font-weight: 500;
-            padding: 0.5rem 1rem;
-            border-radius: 4px;
-            transition: all 0.3s ease;
-        }
-
-        .nav a:hover {
-            background: #363636;
-        }
-
-        #messageInput {
-            flex-grow: 1;
-            padding: 12px;
-            border: 2px solid var(--input-border);
-            border-radius: 4px;
-            font-size: 1rem;
-            transition: all 0.3s ease;
-            background: var(--input-background);
-            color: var(--text-color);
-        }
-
-        #messageInput:focus {
-            outline: none;
-            border-color: var(--primary-color);
-            box-shadow: 0 0 0 3px rgba(114, 137, 218, 0.1);
-        }
-
-        button {
-            background: var(--primary-color);
-            color: white;
-            border: none;
-            padding: 12px 24px;
-            border-radius: 4px;
-            cursor: pointer;
-            font-size: 1rem;
-            transition: all 0.3s ease;
-        }
-
-        button:hover {
-            background: #909090;
-            transform: translateY(-2px);
-        }
-
-        h1 {
-            color: var(--primary-color);
-            text-align: center;
-            margin-bottom: 1.5rem;
-        }
-
-        /* Scrollbar styling */
-        .chat-container::-webkit-scrollbar {
-            width: 8px;
-        }
-
-        .chat-container::-webkit-scrollbar-track {
-            background: #363636;
-        }
-
-        .chat-container::-webkit-scrollbar-thumb {
-            background: #4a4a4a;
-        }
-
-        .chat-container::-webkit-scrollbar-thumb:hover {
-            background: #5a5a5a;
-        }
-
-        /* Add these new styles */
-        .main-container {
-            display: flex;
-            gap: 20px;
-            height: calc(100vh - 100px);
-            /* Adjust for nav and padding */
-        }
-
-        .chat-card {
-            flex: 3;
-            background: var(--card-background);
-            border-radius: var(--border-radius);
-            box-shadow: var(--shadow);
-            padding: 2rem;
-            margin: 1rem 0;
-            display: flex;
-            flex-direction: column;
-            height: fit-content;
-        }
-
-        .sources-card {
-            flex: 1;
-            background: var(--card-background);
-            border-radius: var(--border-radius);
-            box-shadow: var(--shadow);
-            padding: 2rem;
-            margin: 1rem 0;
-            min-width: 250px;
-            display: flex;
-            flex-direction: column;
-            height: auto;
-        }
-
-        .source-item {
-            padding: 10px;
-            margin-bottom: 10px;
-            background: var(--input-background);
-            border-radius: var(--border-radius);
-            font-size: 0.9rem;
-            border: 1px solid var(--input-border);
-        }
-
-        .sources-title {
-            color: var(--text-color);
-            font-size: 1.2rem;
-            margin-bottom: 1rem;
-            padding-bottom: 0.5rem;
-            border-bottom: 1px solid var(--input-border);
-        }
-
-        #sourcesContainer {
-            flex: 1;
-            overflow-y: auto;
-        }
-
-        .logo-container {
-            display: flex;
-            justify-content: center;
-            align-items: center;
-            margin-bottom: 1rem;
-        }
-    </style>
-</head>
-
-<body>
-    <div class="nav">
-        <a href="/">Upload</a>
-        <a href="/chat">Chat</a>
-    </div>
-    <div class="main-container">
-        <div class="chat-card">
-            <div class="logo-container">
-                <img src="./static/Matriv-white.png" alt="Matriv Logo" style="width: 100px; height: auto;">
-            </div>
-            <div class="chat-container" id="chatContainer">
-            </div>
-            <div class="input-container">
-                <input type="text" id="messageInput" placeholder="Type your message...">
-                <button onclick="sendMessage()">Send</button>
-            </div>
-        </div>
-        <div class="sources-card">
-            <h2 class="sources-title">Sources</h2>
-            <div id="sourcesContainer"></div>
-        </div>
-    </div>
-
-    <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
-    <script>
-        const chatContainer = document.getElementById('chatContainer');
-        const messageInput = document.getElementById('messageInput');
-        const sourcesContainer = document.getElementById('sourcesContainer');
-
-        function addMessage(message, isUser) {
-            const messageDiv = document.createElement('div');
-            messageDiv.className = `message ${isUser ? 'user-message' : 'bot-message'}`;
-            messageDiv.textContent = message;
-            chatContainer.appendChild(messageDiv);
-            chatContainer.scrollTop = chatContainer.scrollHeight;
-        }
-
-        function updateSources(sources) {
-            sourcesContainer.innerHTML = '';
-            if (sources && sources.length > 0) {
-                sources.forEach(source => {
-                    const sourceDiv = document.createElement('div');
-                    sourceDiv.className = 'source-item';
-                    sourceDiv.textContent = source;
-                    sourcesContainer.appendChild(sourceDiv);
-                });
-            }
-        }
-
-        async function sendMessage() {
-            const message = messageInput.value.trim();
-            if (!message) return;
-
-            addMessage(message, true);
-            messageInput.value = '';
-
-            try {
-                const response = await fetch('/chat', {
-                    method: 'POST',
-                    headers: {
-                        'Content-Type': 'application/json',
-                    },
-                    body: JSON.stringify({ question: message }),
-                });
-
-                const data = await response.json();
-
-                if (data.error) {
-                    addMessage(data.error, false);
-                    return;
-                }
-
-                // Create a temporary div to render markdown
-                const tempDiv = document.createElement('div');
-                tempDiv.innerHTML = marked.parse(data.answer[0]);
-
-                // Create message div with markdown content
-                const messageDiv = document.createElement('div');
-                messageDiv.className = 'message bot-message';
-                messageDiv.innerHTML = tempDiv.innerHTML;
-
-                chatContainer.appendChild(messageDiv);
-                chatContainer.scrollTop = chatContainer.scrollHeight;
-
-                // Update sources if they exist in the response
-                if (data.answer[1]) {
-                    updateSources(data.answer[1]);
-                }
-            } catch (error) {
-                console.error('Error:', error);
-                addMessage('Sorry, there was an error processing your message.', false);
-            }
-        }
-
-        messageInput.addEventListener('keypress', function (e) {
-            if (e.key === 'Enter') {
-                sendMessage();
-            }
-        });
-    </script>
-</body>
-
-</html>

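The `data.answer[0]` / `data.answer[1]` indexing above works because `RAGPipeline.rag` returns a `(markdown_answer, source_names)` tuple, which FastAPI serializes as a two-element JSON array. Illustrative shape (values are made up):

    # Shape of a successful /chat response consumed by chat.html:
    response_body = {
        "answer": [
            "**Answer** rendered from markdown...",  # data.answer[0]
            ["test", "test2"],                       # data.answer[1] -> sources panel
        ]
    }
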
rag_demo/templates/upload.html DELETED
@@ -1,193 +0,0 @@
-<!DOCTYPE html>
-<html>
-
-<head>
-    <title>PDF Upload</title>
-    <style>
-        :root {
-            --primary-color: #a0a0a0;
-            --background-color: #1a1a1a;
-            --card-background: #2d2d2d;
-            --text-color: #e0e0e0;
-            --border-radius: 12px;
-            --shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
-        }
-
-        body {
-            font-family: 'Segoe UI', Arial, sans-serif;
-            max-width: 900px;
-            margin: 0 auto;
-            padding: 20px;
-            background-color: var(--background-color);
-            color: var(--text-color);
-        }
-
-        .card {
-            background: var(--card-background);
-            border-radius: var(--border-radius);
-            box-shadow: var(--shadow);
-            padding: 2rem;
-            margin: 2rem 0;
-        }
-
-        .upload-form {
-            border: 2px dashed #404040;
-            padding: 2rem;
-            text-align: center;
-            margin: 1.5rem 0;
-            border-radius: var(--border-radius);
-            background: #363636;
-            transition: all 0.3s ease;
-        }
-
-        .upload-form:hover {
-            border-color: var(--primary-color);
-            background: #404040;
-        }
-
-        .nav {
-            background: var(--card-background);
-            padding: 1rem;
-            border-radius: var(--border-radius);
-            box-shadow: var(--shadow);
-            margin-bottom: 2rem;
-        }
-
-        .nav a {
-            margin-right: 20px;
-            text-decoration: none;
-            color: var(--primary-color);
-            font-weight: 500;
-            padding: 0.5rem 1rem;
-            border-radius: 6px;
-            transition: all 0.3s ease;
-        }
-
-        .nav a:hover {
-            background: #363636;
-        }
-
-        h1 {
-            color: var(--primary-color);
-            text-align: center;
-            margin-bottom: 1.5rem;
-        }
-
-        input[type="file"] {
-            display: none;
-        }
-
-        .file-upload-label {
-            display: inline-block;
-            padding: 12px 24px;
-            background: var(--primary-color);
-            color: white;
-            border-radius: 6px;
-            cursor: pointer;
-            transition: all 0.3s ease;
-        }
-
-        .file-upload-label:hover {
-            background: #909090;
-        }
-
-        .selected-file {
-            margin-top: 1rem;
-            color: #b0b0b0;
-        }
-
-        button {
-            background: var(--primary-color);
-            color: white;
-            border: none;
-            padding: 12px 24px;
-            border-radius: 6px;
-            cursor: pointer;
-            font-size: 1rem;
-            transition: all 0.3s ease;
-            margin-top: 1rem;
-        }
-
-        button:hover {
-            background: #909090;
-            transform: translateY(-2px);
-        }
-
-        .status-message {
-            margin-top: 1rem;
-            padding: 1rem;
-            border-radius: 6px;
-            text-align: center;
-        }
-
-        .success {
-            background: #2e4a3d;
-            color: #7ee2b8;
-        }
-
-        .error {
-            background: #4a2e2e;
-            color: #e27e7e;
-        }
-
-        .loading-placeholder {
-            display: none;
-            margin-top: 1rem;
-            color: #b0b0b0;
-            animation: pulse 1.5s infinite;
-        }
-
-        @keyframes pulse {
-            0% {
-                opacity: 0.6;
-            }
-
-            50% {
-                opacity: 1;
-            }
-
-            100% {
-                opacity: 0.6;
-            }
-        }
-    </style>
-</head>
-
-<body>
-    <div class="nav">
-        <a href="/">Upload</a>
-        <a href="/chat">Chat</a>
-    </div>
-    <div class="card">
-        <h1>Upload Documents</h1>
-        <div class="upload-form">
-            <form action="/upload" method="post" enctype="multipart/form-data" id="uploadForm">
-                <label for="file-upload" class="file-upload-label">
-                    Choose PDF Files
-                </label>
-                <input id="file-upload" type="file" name="file" accept=".pdf" multiple onchange="updateFileName(this)">
-                <div id="selectedFile" class="selected-file"></div>
-                <div id="loadingPlaceholder" class="loading-placeholder">Processing file...</div>
-                <button type="submit" onclick="showLoading()">Upload</button>
-            </form>
-        </div>
-    </div>
-
-    <script>
-        function updateFileName(input) {
-            const fileNames = Array.from(input.files)
-                .map(file => file.name)
-                .join(', ');
-            document.getElementById('selectedFile').textContent = fileNames || 'No file selected';
-        }
-
-        function showLoading() {
-            if (document.getElementById('file-upload').files.length > 0) {
-                document.getElementById('selectedFile').style.display = 'none';
-                document.getElementById('loadingPlaceholder').style.display = 'block';
-            }
-        }
-    </script>
-</body>
-
-</html>