Spaces:
Running
Running
import functools
import hmac
import os

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from pydantic import BaseModel
from pymongo import MongoClient
# Load variables from a local .env file before any os.getenv() call below.
load_dotenv()

# MongoDB connection and Langchain setup (as provided)
# os.getenv() yields None when the variable is unset.
# NOTE(review): confirm how MongoClient behaves when given None as its URI.
MONGODB_ATLAS_CLUSTER_URI = os.getenv("MONGODB_ATLAS_CLUSTER_URI")
DB_NAME = "quran_db"
COLLECTION_NAME = "tafsir"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain_index"

# Direct handle on the collection.
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]

# Embedding model and the Atlas vector store built on top of it; the store is
# addressed by its "<database>.<collection>" namespace.
embeddings = SentenceTransformerEmbeddings(model_name="BAAI/bge-m3")
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    MONGODB_ATLAS_CLUSTER_URI,
    f"{DB_NAME}.{COLLECTION_NAME}",
    embeddings,
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
# FastAPI application setup
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended against the FastAPI CORS documentation.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
def index_file(filepath):
    """Map every stripped line of a UTF-8 text file to its 1-based line number.

    When the same stripped text appears on several lines, the number of the
    last occurrence is the one kept.
    """
    with open(filepath, mode='r', encoding='utf-8') as handle:
        # Numbering starts at 1 so results match what a human sees in an editor.
        return {text.strip(): position for position, text in enumerate(handle, start=1)}
def get_text_by_line_number(filepath, line_numbers):
    """Retrieve specific lines from a UTF-8 text file by 1-based line number.

    Args:
        filepath: Path of the file to read.
        line_numbers: Iterable of 1-based line numbers to fetch. Duplicates
            are ignored; numbers beyond the end of the file are skipped.

    Returns:
        A dict mapping each requested line number found in the file to that
        line's text with surrounding whitespace stripped, in file order.
        An empty request returns an empty dict without opening the file.
    """
    # A set gives O(1) membership per file line and removes duplicates, so the
    # early-exit size comparison below can actually be reached.
    wanted = set(line_numbers)
    lines = {}
    if not wanted:
        return lines
    with open(filepath, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, 1):
            if line_number in wanted:
                lines[line_number] = line.strip()
                if len(lines) == len(wanted):  # Stop reading once all required lines are read
                    break
    return lines
# Existing API endpoints
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def read_root():
    """Return the static welcome payload."""
    welcome = {"message": "Welcome to our app"}
    return welcome
# New Query model for the POST request body
class Item(BaseModel):
    """Request payload carrying the question to answer."""

    # Free-text question used as the similarity-search query.
    question: str
# Shared secret that clients must present; None when API_TOKEN is not set.
EXPECTED_TOKEN = os.getenv("API_TOKEN")


def verify_token(authorization: str = Header(None)):
    """
    Dependency to verify the Authorization header contains the correct Bearer token.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token that was presented, once it has been accepted.

    Raises:
        HTTPException: 401 when the header is missing, does not start with
            ``"Bearer "``, the token does not match ``EXPECTED_TOKEN``, or no
            non-empty ``EXPECTED_TOKEN`` is configured.
    """
    # Prefix for bearer token in the Authorization header
    prefix = "Bearer "
    # Check if the Authorization header is present and correctly formatted
    if not authorization or not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail="Unauthorized: Missing or invalid token")
    # Extract the token from the Authorization header
    token = authorization[len(prefix):]
    # Fail closed when no (or an empty) token is configured, so an empty bearer
    # value can never authenticate. Compare as bytes in constant time to avoid
    # leaking how much of the token matched through response timing.
    if not EXPECTED_TOKEN or not hmac.compare_digest(
        token.encode("utf-8"), EXPECTED_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized: Incorrect token")
    return token
@functools.lru_cache(maxsize=None)
def _load_answers_index(filepath):
    """Build the content -> line-number index for *filepath* once and reuse it.

    The result is cached per path for the life of the process, so edits to the
    file are not picked up until restart. A failed load is not cached.
    """
    return index_file(filepath)


# New API endpoint to get an answer using the chain
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def get_answer(item: Item, token: str = Depends(verify_token)):
    """Answer a question by vector search, then map the hits to source lines.

    The three closest documents are fetched from the vector store, their text
    is flattened to a single line, and each one is looked up in the index of
    'app/quran_tafseer_formatted.txt'. The lines with the same numbers are
    then read from 'app/quran_tafseer.txt'.

    Args:
        item: Request body holding the question.
        token: Value produced by the ``verify_token`` dependency; unused here
            beyond enforcing authentication.

    Returns:
        ``{"result_text": {line_number: text, ...}}`` for every hit that was
        found in the index.

    Raises:
        HTTPException: 500 carrying the error text when any step fails.
    """
    try:
        # Perform the similarity search with the provided question
        # NOTE(review): this is a synchronous call inside a coroutine, so it
        # holds the event loop while it runs — consider offloading to a thread.
        matching_docs = vector_search.similarity_search(item.question, k=3)
        clean_answers = [doc.page_content.replace("\n", " ").strip() for doc in matching_docs]
        # Index of the formatted file, built on first use instead of per request
        answers_index = _load_answers_index('app/quran_tafseer_formatted.txt')
        # Collect line numbers based on answers found
        line_numbers = [answers_index[answer] for answer in clean_answers if answer in answers_index]
        # Fetch the lines with the same numbers from the unformatted file
        result_text = get_text_by_line_number('app/quran_tafseer.txt', line_numbers)
        return {"result_text": result_text}
    except Exception as e:
        # If there's an error, return a 500 error with the error's details
        raise HTTPException(status_code=500, detail=str(e)) from e
# if __name__ == "__main__":
#     uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=False)