Hammad712 commited on
Commit
4b478d6
·
verified ·
1 Parent(s): 3428cef

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +111 -0
main.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from langchain.chains import RetrievalQA
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain_groq import ChatGroq
8
+ import zipfile
9
+ import os
10
+
11
+ app = FastAPI()
12
+
13
+ # === Startup config ===
14
+ class QueryRequest(BaseModel):
15
+ question: str
16
+
17
+ llm = None
18
+ retriever = None
19
+ chain = None
20
+
21
+ @app.on_event("startup")
22
+ def load_components():
23
+ global llm, retriever, chain
24
+
25
+ api_key = os.getenv('api_key')
26
+
27
+ # --- Load LLM ---
28
+ llm = ChatGroq(
29
+ model="meta-llama/llama-4-scout-17b-16e-instruct",
30
+ temperature=0,
31
+ max_tokens=1024,
32
+ api_key=api_key
33
+ )
34
+
35
+ # --- Load Embeddings ---
36
+ embeddings = HuggingFaceEmbeddings(
37
+ model_name="intfloat/multilingual-e5-large",
38
+ model_kwargs={"device": "cpu"},
39
+ encode_kwargs={"normalize_embeddings": True},
40
+ )
41
+
42
+ # --- Unzip Vectorstore if needed ---
43
+ zip_path = "faiss_index.zip"
44
+ extract_path = "faiss_index"
45
+ if not os.path.exists(extract_path):
46
+ with zipfile.ZipFile(zip_path, 'r') as z:
47
+ z.extractall(extract_path)
48
+ print("✅ Unzipped FAISS index.")
49
+
50
+ # --- Load FAISS Vectorstore & create retriever ---
51
+ vectorstore = FAISS.load_local(
52
+ extract_path,
53
+ embeddings,
54
+ allow_dangerous_deserialization=True
55
+ )
56
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
57
+ print("✅ FAISS index loaded.")
58
+
59
+ # --- Prepare prompt template ---
60
+ quiz_solving_prompt = """
61
+ You are an Arabic Hadith Finder assistant.
62
+ Your goal is to provide an accurate and concise answer extracted directly from the provided retrieved context.
63
+ Your task is to output only the exact Arabic Hadith (as it appears in the context), removing any extraneous or irrelevant data.
64
+
65
+ Instructions:
66
+ 1. Identify the segment in the retrieved context that directly answers the user's question.
67
+ 2. Output the Hadith exactly as it appears in Arabic in the context.
68
+ 3. Remove any information that does not pertain directly to the query.
69
+ 4. If the context does not contain sufficient information to answer the question, respond with "لا أعلم". Do not add or infer any extra information.
70
+ 5. Provide the complete reference of the Hadith (if available), including:
71
+ - Chapter Number and Name (Arabic and/or English)
72
+ - Section Number and Name
73
+ - Hadith Number
74
+ - Arabic Isnad and Matn
75
+ - Arabic Grade (if present)
76
+ - Hadith Book name
77
+
78
+ Retrieved context:
79
+ {context}
80
+
81
+ User's question:
82
+ {question}
83
+
84
+ Your response:
85
+ """
86
+ prompt = PromptTemplate(
87
+ template=quiz_solving_prompt,
88
+ input_variables=["context", "question"]
89
+ )
90
+
91
+ # --- Assemble a stateless RetrievalQA chain (no memory) ---
92
+ chain = RetrievalQA.from_chain_type(
93
+ llm=llm,
94
+ chain_type="stuff",
95
+ retriever=retriever,
96
+ return_source_documents=False,
97
+ chain_type_kwargs={"prompt": prompt},
98
+ verbose=False,
99
+ )
100
+
101
+ @app.get("/")
102
+ def root():
103
+ return {"message": "Arabic Hadith Finder API is up..."}
104
+
105
+ @app.post("/query")
106
+ def query(request: QueryRequest):
107
+ try:
108
+ result = chain.invoke({"query": request.question})
109
+ return {"answer": result["result"]}
110
+ except Exception as e:
111
+ raise HTTPException(status_code=500, detail=str(e))