Update main.py
Browse files
main.py
CHANGED
@@ -27,27 +27,25 @@ class QueryRequest(BaseModel):
|
|
27 |
|
28 |
def _unpack_faiss(src_path: str, extract_to: str) -> str:
|
29 |
"""
|
30 |
-
If src_path is a .zip, unzip
|
31 |
-
|
|
|
32 |
"""
|
33 |
-
# 1) ZIP
|
34 |
-
if
|
35 |
-
if not os.path.isfile(src_path):
|
36 |
-
raise FileNotFoundError(f"Could not find zip file: {src_path}")
|
37 |
with zipfile.ZipFile(src_path, "r") as zf:
|
38 |
zf.extractall(extract_to)
|
39 |
-
|
40 |
-
# walk until we find any .faiss file
|
41 |
for root, _, files in os.walk(extract_to):
|
42 |
-
if any(
|
43 |
return root
|
44 |
-
raise RuntimeError(f"No .faiss index found inside {src_path}")
|
45 |
|
46 |
-
# 2)
|
47 |
if os.path.isdir(src_path):
|
48 |
return src_path
|
49 |
|
50 |
-
raise RuntimeError(f"Path is neither a
|
51 |
|
52 |
|
53 |
@app.on_event("startup")
|
@@ -61,7 +59,6 @@ def load_components():
|
|
61 |
max_tokens=1024,
|
62 |
api_key=os.getenv("api_key"),
|
63 |
)
|
64 |
-
|
65 |
embeddings = HuggingFaceEmbeddings(
|
66 |
model_name="intfloat/multilingual-e5-large",
|
67 |
model_kwargs={"device": "cpu"},
|
@@ -69,28 +66,24 @@ def load_components():
|
|
69 |
)
|
70 |
|
71 |
# --- 2) Load & merge two FAISS indexes ---
|
72 |
-
|
73 |
-
|
|
|
74 |
|
75 |
-
# Use TemporaryDirectory objects so they stick around until program exit
|
76 |
tmp1 = tempfile.TemporaryDirectory()
|
77 |
tmp2 = tempfile.TemporaryDirectory()
|
78 |
|
79 |
-
# Unpack & locate
|
80 |
dir1 = _unpack_faiss(src1, tmp1.name)
|
81 |
dir2 = _unpack_faiss(src2, tmp2.name)
|
82 |
|
83 |
-
# Load them
|
84 |
vs1 = FAISS.load_local(dir1, embeddings, allow_dangerous_deserialization=True)
|
85 |
vs2 = FAISS.load_local(dir2, embeddings, allow_dangerous_deserialization=True)
|
86 |
|
87 |
-
# Merge vs2 into vs1
|
88 |
vs1.merge_from(vs2)
|
89 |
vectorstore = vs1
|
90 |
|
91 |
# --- 3) Build retriever & QA chain ---
|
92 |
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
93 |
-
|
94 |
prompt = PromptTemplate(
|
95 |
template="""
|
96 |
You are an expert assistant on Islamic knowledge.
|
@@ -108,7 +101,6 @@ Your response:
|
|
108 |
""",
|
109 |
input_variables=["context", "question"],
|
110 |
)
|
111 |
-
|
112 |
chain = RetrievalQA.from_chain_type(
|
113 |
llm=llm,
|
114 |
chain_type="stuff",
|
|
|
27 |
|
28 |
def _unpack_faiss(src_path: str, extract_to: str) -> str:
|
29 |
"""
|
30 |
+
If src_path is a valid .zip archive, unzip it into extract_to and
|
31 |
+
return the subdirectory that contains the .faiss index.
|
32 |
+
If src_path is already a directory, return it directly.
|
33 |
"""
|
34 |
+
# 1) True ZIP file?
|
35 |
+
if zipfile.is_zipfile(src_path):
|
|
|
|
|
36 |
with zipfile.ZipFile(src_path, "r") as zf:
|
37 |
zf.extractall(extract_to)
|
38 |
+
# scan until we find any .faiss file
|
|
|
39 |
for root, _, files in os.walk(extract_to):
|
40 |
+
if any(f.endswith(".faiss") for f in files):
|
41 |
return root
|
42 |
+
raise RuntimeError(f"No .faiss index found inside ZIP: {src_path}")
|
43 |
|
44 |
+
# 2) Already a folder?
|
45 |
if os.path.isdir(src_path):
|
46 |
return src_path
|
47 |
|
48 |
+
raise RuntimeError(f"Path is neither a valid ZIP nor a directory: {src_path}")
|
49 |
|
50 |
|
51 |
@app.on_event("startup")
|
|
|
59 |
max_tokens=1024,
|
60 |
api_key=os.getenv("api_key"),
|
61 |
)
|
|
|
62 |
embeddings = HuggingFaceEmbeddings(
|
63 |
model_name="intfloat/multilingual-e5-large",
|
64 |
model_kwargs={"device": "cpu"},
|
|
|
66 |
)
|
67 |
|
68 |
# --- 2) Load & merge two FAISS indexes ---
|
69 |
+
# (these can be either real .zip files or existing folders)
|
70 |
+
src1 = "faiss_index.zip" # or "faiss_index" if it's already a folder
|
71 |
+
src2 = "faiss_index_extra.zip" # or "faiss_index_extra"
|
72 |
|
|
|
73 |
tmp1 = tempfile.TemporaryDirectory()
|
74 |
tmp2 = tempfile.TemporaryDirectory()
|
75 |
|
|
|
76 |
dir1 = _unpack_faiss(src1, tmp1.name)
|
77 |
dir2 = _unpack_faiss(src2, tmp2.name)
|
78 |
|
|
|
79 |
vs1 = FAISS.load_local(dir1, embeddings, allow_dangerous_deserialization=True)
|
80 |
vs2 = FAISS.load_local(dir2, embeddings, allow_dangerous_deserialization=True)
|
81 |
|
|
|
82 |
vs1.merge_from(vs2)
|
83 |
vectorstore = vs1
|
84 |
|
85 |
# --- 3) Build retriever & QA chain ---
|
86 |
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
|
|
87 |
prompt = PromptTemplate(
|
88 |
template="""
|
89 |
You are an expert assistant on Islamic knowledge.
|
|
|
101 |
""",
|
102 |
input_variables=["context", "question"],
|
103 |
)
|
|
|
104 |
chain = RetrievalQA.from_chain_type(
|
105 |
llm=llm,
|
106 |
chain_type="stuff",
|