Hammad712 commited on
Commit
376d7c4
·
verified ·
1 Parent(s): a347f56

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +13 -21
main.py CHANGED
@@ -27,27 +27,25 @@ class QueryRequest(BaseModel):
27
 
28
  def _unpack_faiss(src_path: str, extract_to: str) -> str:
29
  """
30
- If src_path is a .zip, unzip to extract_to and return the directory
31
- containing the .faiss file. If it's already a folder, just return it.
 
32
  """
33
- # 1) ZIP case
34
- if src_path.lower().endswith(".zip"):
35
- if not os.path.isfile(src_path):
36
- raise FileNotFoundError(f"Could not find zip file: {src_path}")
37
  with zipfile.ZipFile(src_path, "r") as zf:
38
  zf.extractall(extract_to)
39
-
40
- # walk until we find any .faiss file
41
  for root, _, files in os.walk(extract_to):
42
- if any(fn.endswith(".faiss") for fn in files):
43
  return root
44
- raise RuntimeError(f"No .faiss index found inside {src_path}")
45
 
46
- # 2) directory case
47
  if os.path.isdir(src_path):
48
  return src_path
49
 
50
- raise RuntimeError(f"Path is neither a .zip nor a directory: {src_path}")
51
 
52
 
53
  @app.on_event("startup")
@@ -61,7 +59,6 @@ def load_components():
61
  max_tokens=1024,
62
  api_key=os.getenv("api_key"),
63
  )
64
-
65
  embeddings = HuggingFaceEmbeddings(
66
  model_name="intfloat/multilingual-e5-large",
67
  model_kwargs={"device": "cpu"},
@@ -69,28 +66,24 @@ def load_components():
69
  )
70
 
71
  # --- 2) Load & merge two FAISS indexes ---
72
- src1 = "faiss_index.zip"
73
- src2 = "faiss_index_extra.zip"
 
74
 
75
- # Use TemporaryDirectory objects so they stick around until program exit
76
  tmp1 = tempfile.TemporaryDirectory()
77
  tmp2 = tempfile.TemporaryDirectory()
78
 
79
- # Unpack & locate
80
  dir1 = _unpack_faiss(src1, tmp1.name)
81
  dir2 = _unpack_faiss(src2, tmp2.name)
82
 
83
- # Load them
84
  vs1 = FAISS.load_local(dir1, embeddings, allow_dangerous_deserialization=True)
85
  vs2 = FAISS.load_local(dir2, embeddings, allow_dangerous_deserialization=True)
86
 
87
- # Merge vs2 into vs1
88
  vs1.merge_from(vs2)
89
  vectorstore = vs1
90
 
91
  # --- 3) Build retriever & QA chain ---
92
  retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
93
-
94
  prompt = PromptTemplate(
95
  template="""
96
  You are an expert assistant on Islamic knowledge.
@@ -108,7 +101,6 @@ Your response:
108
  """,
109
  input_variables=["context", "question"],
110
  )
111
-
112
  chain = RetrievalQA.from_chain_type(
113
  llm=llm,
114
  chain_type="stuff",
 
27
 
28
  def _unpack_faiss(src_path: str, extract_to: str) -> str:
29
  """
30
+ If src_path is a valid .zip archive, unzip it into extract_to and
31
+ return the subdirectory that contains the .faiss index.
32
+ If src_path is already a directory, return it directly.
33
  """
34
+ # 1) True ZIP file?
35
+ if zipfile.is_zipfile(src_path):
 
 
36
  with zipfile.ZipFile(src_path, "r") as zf:
37
  zf.extractall(extract_to)
38
+ # scan until we find any .faiss file
 
39
  for root, _, files in os.walk(extract_to):
40
+ if any(f.endswith(".faiss") for f in files):
41
  return root
42
+ raise RuntimeError(f"No .faiss index found inside ZIP: {src_path}")
43
 
44
+ # 2) Already a folder?
45
  if os.path.isdir(src_path):
46
  return src_path
47
 
48
+ raise RuntimeError(f"Path is neither a valid ZIP nor a directory: {src_path}")
49
 
50
 
51
  @app.on_event("startup")
 
59
  max_tokens=1024,
60
  api_key=os.getenv("api_key"),
61
  )
 
62
  embeddings = HuggingFaceEmbeddings(
63
  model_name="intfloat/multilingual-e5-large",
64
  model_kwargs={"device": "cpu"},
 
66
  )
67
 
68
  # --- 2) Load & merge two FAISS indexes ---
69
+ # (these can be either real .zip files or existing folders)
70
+ src1 = "faiss_index.zip" # or "faiss_index" if it's already a folder
71
+ src2 = "faiss_index_extra.zip" # or "faiss_index_extra"
72
 
 
73
  tmp1 = tempfile.TemporaryDirectory()
74
  tmp2 = tempfile.TemporaryDirectory()
75
 
 
76
  dir1 = _unpack_faiss(src1, tmp1.name)
77
  dir2 = _unpack_faiss(src2, tmp2.name)
78
 
 
79
  vs1 = FAISS.load_local(dir1, embeddings, allow_dangerous_deserialization=True)
80
  vs2 = FAISS.load_local(dir2, embeddings, allow_dangerous_deserialization=True)
81
 
 
82
  vs1.merge_from(vs2)
83
  vectorstore = vs1
84
 
85
  # --- 3) Build retriever & QA chain ---
86
  retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
 
87
  prompt = PromptTemplate(
88
  template="""
89
  You are an expert assistant on Islamic knowledge.
 
101
  """,
102
  input_variables=["context", "question"],
103
  )
 
104
  chain = RetrievalQA.from_chain_type(
105
  llm=llm,
106
  chain_type="stuff",