Spaces:
Sleeping
Sleeping
File size: 5,093 Bytes
e13455c 1a102e1 e13455c dc43c61 e13455c 5b25c6e e13455c dc43c61 642181a 886c1e1 642181a 886c1e1 642181a 886c1e1 642181a 886c1e1 642181a 886c1e1 642181a e13455c 642181a 706b71b 642181a 886c1e1 e13455c 642181a e13455c dc43c61 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import hmac
import logging
import os

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from pydantic import BaseModel
from pymongo import MongoClient
# Pull settings from a local .env file (if any) before reading the environment.
load_dotenv()

# MongoDB connection settings
MONGODB_ATLAS_CLUSTER_URI = os.getenv("MONGODB_ATLAS_CLUSTER_URI")
DB_NAME = "quran_db"
COLLECTION_NAME = "tafsir"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain_index"

# Direct handle on the collection, alongside the vector-search wrapper below.
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]

# Embedding model used to encode queries for the vector search.
embeddings = SentenceTransformerEmbeddings(model_name="BAAI/bge-m3")

# Vector store addressed by its "<database>.<collection>" namespace.
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    MONGODB_ATLAS_CLUSTER_URI,
    f"{DB_NAME}.{COLLECTION_NAME}",
    embeddings,
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
# FastAPI application setup
app = FastAPI()

# Browser origins allowed to call the API. Defaults to the wildcard "*"; set
# CORS_ALLOW_ORIGINS to a comma-separated list (for example
# "https://a.example,https://b.example") to restrict it without a code change.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # NOTE(review): the FastAPI CORS docs advise against combining a wildcard
    # origin with allow_credentials=True. The flag is kept as-is so existing
    # clients keep working; prefer setting explicit origins in production.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def index_file(filepath):
    """ Index each block in a file separated by double newlines for quick search.

    Every block is normalised by replacing its inner newlines with spaces and
    stripping surrounding whitespace. Block numbers start at 1 and count every
    block, including blank ones, so they line up with the numbering used by
    get_text_by_block_number().

    Blank blocks are not stored: otherwise an empty lookup string would
    resolve to whichever blank block happened to come last. When two blocks
    normalise to the same text, the later block number wins.

    Returns a dictionary with key as content and value as block number. """
    with open(filepath, 'r', encoding='utf-8') as file:
        blocks = file.read().split("\n\n")  # Blocks are separated by a blank line

    index = {}
    for block_number, block in enumerate(blocks, 1):
        formatted_block = block.replace('\n', ' ').strip()
        if formatted_block:  # Skip blank blocks but keep their number reserved
            index[formatted_block] = block_number
    return index
def get_text_by_block_number(filepath, block_numbers):
    """ Retrieve specific blocks from a file based on block numbers, where each block is separated by '\\n\\n'.

    Block numbers are 1-based. Each returned block has its inner newlines
    replaced by spaces and surrounding whitespace stripped.

    The result is in file order, not in the order of block_numbers; duplicate
    or out-of-range numbers contribute nothing extra. An empty request returns
    an empty list without opening the file.
    """
    wanted = set(block_numbers)  # O(1) membership; also collapses duplicates
    if not wanted:
        return []

    with open(filepath, 'r', encoding='utf-8') as file:
        blocks = file.read().split("\n\n")  # Blocks are separated by a blank line

    blocks_text = []
    for block_number, block in enumerate(blocks, 1):
        if block_number not in wanted:
            continue
        blocks_text.append(block.replace('\n', ' ').strip())
        # Compare against the de-duplicated count so the early exit still
        # triggers when the caller passed the same number more than once.
        if len(blocks_text) == len(wanted):
            break
    return blocks_text
# Existing API endpoints
@app.get("/")
async def read_root():
    """Serve the static welcome payload at the API root."""
    welcome = {"message": "Welcome to our app"}
    return welcome
# New Query model for the POST request body
class Item(BaseModel):
    """Request body accepted by the POST /get_answer endpoint."""

    # Free-text question that is passed to the vector similarity search.
    question: str
# Shared secret clients must present as a Bearer token; read once at import time.
EXPECTED_TOKEN = os.getenv("API_TOKEN")


def verify_token(authorization: str = Header(None)):
    """
    Dependency to verify the Authorization header contains the correct Bearer token.

    Args:
        authorization: Raw value of the Authorization request header, or None
            when the header is absent.

    Returns:
        The bearer token that was presented, once it has been validated.

    Raises:
        HTTPException: 401 when the header is missing, does not start with
            "Bearer ", or carries a token that does not match EXPECTED_TOKEN.
            Every token is rejected when API_TOKEN is unset or empty.
    """
    # Prefix for bearer token in the Authorization header
    prefix = "Bearer "
    # A 401 response should tell the client which auth scheme is expected.
    challenge = {"WWW-Authenticate": "Bearer"}

    # Check if the Authorization header is present and correctly formatted
    if not authorization or not authorization.startswith(prefix):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Missing or invalid token",
            headers=challenge,
        )

    # Extract the token from the Authorization header
    token = authorization[len(prefix):]

    # Fail closed when no secret is configured (an empty API_TOKEN must not
    # make an empty bearer token valid), and compare in constant time so the
    # check does not leak how much of the token matched. Encoding to bytes
    # lets compare_digest accept non-ASCII input.
    token_matches = bool(EXPECTED_TOKEN) and hmac.compare_digest(
        token.encode("utf-8"), EXPECTED_TOKEN.encode("utf-8")
    )
    if not token_matches:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Incorrect token",
            headers=challenge,
        )

    # Hand the validated token to endpoints that declare it via Depends().
    return token
logger = logging.getLogger(__name__)

# File whose normalised blocks are matched against the vector-search hits.
FORMATTED_TAFSIR_PATH = 'app/quran_tafseer_formatted.txt'
# File from which the blocks with the matching numbers are returned.
TAFSIR_PATH = 'app/quran_tafseer.txt'

# index_file() results keyed by path, so the file is read and indexed once per
# process instead of on every request. Changes to the file on disk are only
# picked up after a restart.
_block_index_cache = {}


def _get_block_index(filepath):
    """Return the content -> block-number index for filepath, building it on first use."""
    cached = _block_index_cache.get(filepath)
    if cached is None:
        cached = _block_index_cache[filepath] = index_file(filepath)
    return cached


# New API endpoint to get an answer using the chain
@app.post("/get_answer")
async def get_answer(item: Item, token: str = Depends(verify_token)):
    """
    Answer a question with the source-text blocks closest to it.

    The three nearest documents from the vector search are normalised
    (newlines replaced by spaces, surrounding whitespace stripped), looked up
    in the index of FORMATTED_TAFSIR_PATH to get block numbers, and the
    blocks with those numbers are then read from TAFSIR_PATH. Hits with no
    exact match in the index are dropped.

    Args:
        item: Request body carrying the question.
        token: Result of the verify_token dependency; it is not used here and
            only enforces authentication.

    Returns:
        A dict {"result_text": [...]} holding the retrieved blocks.

    Raises:
        HTTPException: 500 when the search or file lookup fails. The cause is
            logged on the server and not echoed to the client.
    """
    # NOTE(review): similarity_search and the file reads are synchronous calls
    # inside an async handler, so other requests presumably wait while they
    # run. Consider offloading them to a worker thread.
    try:
        # Perform the similarity search with the provided question
        matching_docs = vector_search.similarity_search(item.question, k=3)
        clean_answers = [doc.page_content.replace("\n", " ").strip() for doc in matching_docs]
        # Map each hit back to its block number in the formatted file
        answers_index = _get_block_index(FORMATTED_TAFSIR_PATH)
        block_numbers = [answers_index[answer] for answer in clean_answers if answer in answers_index]
        # Retrieve the blocks with the same numbers from the second file
        result_text = get_text_by_block_number(TAFSIR_PATH, block_numbers)
    except Exception as exc:
        # Keep internal details (paths, connection errors) out of the response.
        logger.exception("get_answer failed")
        raise HTTPException(status_code=500, detail="Internal server error") from exc
    return {"result_text": result_text}
# if __name__ == "__main__":
# uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=False)
|