shahabkahn's picture
Upload 11 files
318f2bb verified
raw
history blame
4.78 kB
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.chains import RetrievalQA
from langchain_community.llms import CTransformers
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import re
import uvicorn
import logging
app = FastAPI()

# CORS configuration.
# Every origin, method and header is allowed.  Credentials are switched off:
# the wildcard origin cannot be validly combined with allow_credentials=True
# (FastAPI's CORS documentation requires explicit origins in that case), and
# the endpoint defined in this module reads no cookies or auth headers.
# To support credentialed requests, list the trusted origins explicitly in
# allow_origins and only then set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Set up logging
# Configure the root logger at INFO level (basicConfig is a no-op if the root
# logger already has handlers).
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Embedding model and FAISS index, both loaded once at import time.
# The embedding model is pinned to the CPU.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

# NOTE(review): allow_dangerous_deserialization=True opts in to an unsafe
# deserializer (pickle, per the LangChain docs) -- only point this at an index
# directory that was built by a trusted process.
try:
    db = FAISS.load_local(
        "vectorstore/db_faiss",
        embeddings,
        allow_dangerous_deserialization=True,
    )
except Exception as e:
    # Log a short message, then let the original exception abort start-up.
    logger.error(f"Failed to load vector database: {e}")
    raise
else:
    logger.info("Vector database loaded successfully!")
# Load LLM
# Generation settings are passed through `config`, the mapping the
# CTransformers wrapper documents for ctransformers options.  As top-level
# keyword arguments (max_new_tokens=..., temperature=...) they are not
# declared fields of the wrapper, so they would not reach the model.
try:
    llm = CTransformers(
        model="llama-2-7b-chat.ggmlv3.q4_0.bin",
        model_type="llama",
        config={"max_new_tokens": 128, "temperature": 0.5},
    )
    logger.info("LLM model loaded successfully!")
except Exception as e:
    # Log a short message, then re-raise the original exception unchanged so
    # the process fails fast when the model file is missing or unreadable.
    logger.error(f"Failed to load LLM model: {e}")
    raise
# Prompt text for question answering.  It exposes two placeholders,
# {context} and {question}, and instructs the model to admit when it does not
# know the answer instead of inventing one.
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""

# Wrap the raw text in a PromptTemplate, declaring both placeholders.
qa_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=custom_prompt_template,
)
# Retrieval-augmented QA chain built once at import time:
#   * the retriever is created from the FAISS store with search_kwargs k=2,
#   * the "stuff" chain type is used together with the custom prompt,
#   * source documents are returned alongside the answer.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    chain_type="stuff",
    chain_type_kwargs={"prompt": qa_prompt},
    return_source_documents=True,
)
class QuestionRequest(BaseModel):
    """Request body for the /query endpoint."""

    # The user's question, as a string.
    question: str
class AnswerResponse(BaseModel):
    """Response body for the /query endpoint."""

    # The answer text returned to the client.
    answer: str
# Patterns are compiled once; clean_answer runs on every request.
_ESCAPED_NEWLINE = re.compile(r"\\n")
_HEADING_MARKER = re.compile(r"\s*(#+)\s+")
_DISALLOWED_CHARS = re.compile(r"[^\w\s.,-]")
_WHITESPACE_RUN = re.compile(r"\s+")
_REPEATED_WORD = re.compile(r"\b(\w+)( \1\b)+")
_BULLET_ITEM = re.compile(r"^-\s+(.*)$")
_NUMBERED_ITEM = re.compile(r"^\d+\.\s+(.*)$")
_BLANK_LINE_RUN = re.compile(r"\n{3,}")


def _clean_line(line):
    """Clean a single line and normalise its markdown prefix, if it has one."""
    # A leading "#..." heading marker is set aside before symbols are
    # stripped; otherwise "#" would be removed with the other symbols.
    marker = ""
    heading = _HEADING_MARKER.match(line)
    if heading:
        marker = heading.group(1) + " "
        line = line[heading.end():]

    # Keep only word characters, whitespace and . , -
    line = _DISALLOWED_CHARS.sub("", line)
    # Collapse whitespace first so "the   the" is also seen as a repetition.
    line = _WHITESPACE_RUN.sub(" ", line).strip()
    # Collapse immediately repeated words ("the the" -> "the").
    line = _REPEATED_WORD.sub(r"\1", line)

    if marker:
        return (marker + line).rstrip()

    # "- item" -> "* item"; "7. item" -> "1. item" (markdown renumbers lists).
    line = _BULLET_ITEM.sub(r"* \1", line)
    line = _NUMBERED_ITEM.sub(r"1. \1", line)
    return line


def clean_answer(answer):
    """Return *answer* stripped of stray symbols and formatted as markdown.

    Processing is done line by line so that line structure survives:

    * literal two-character ``\\n`` sequences become real newlines,
    * characters other than word characters, whitespace, ``.``, ``,`` and
      ``-`` are removed (a leading ``#`` heading marker is kept),
    * runs of whitespace inside a line collapse to one space,
    * immediately repeated words are collapsed to a single occurrence,
    * ``- item`` lines become ``* item`` and ``N. item`` lines become
      ``1. item``,
    * runs of blank lines are reduced to a single blank line.

    Previously backslashes and ``#`` were stripped, and all whitespace
    (newlines included) was collapsed, *before* the newline, list and heading
    rules ran, so those rules could never take effect on multi-line text.

    Args:
        answer: Raw text to clean; ``None`` or an empty string is accepted.

    Returns:
        The cleaned text, or ``""`` when *answer* is ``None`` or empty.
    """
    if not answer:
        return ""

    # Must happen before symbol stripping, which would remove the backslash.
    text = _ESCAPED_NEWLINE.sub("\n", answer)
    cleaned = "\n".join(_clean_line(line) for line in text.splitlines())
    cleaned = _BLANK_LINE_RUN.sub("\n\n", cleaned)
    return cleaned.strip()
def format_sources(sources):
    """Build a newline-separated listing of the given source documents.

    Each element of *sources* must expose a ``metadata`` mapping.  One line is
    produced per element, in order, of the form ``"<source>, page <page>"``;
    missing keys fall back to ``"Unknown source"`` / ``"Unknown page"``.
    An empty iterable yields an empty string.
    """

    def describe(document):
        # Read both fields with their fallbacks, then render one line.
        meta = document.metadata
        page = meta.get('page', 'Unknown page')
        origin = meta.get('source', 'Unknown source')
        return f"{origin}, page {page}"

    return "\n".join(map(describe, sources))
@app.post("/query", response_model=AnswerResponse)
async def query(question_request: QuestionRequest):
    """Answer a question with the retrieval QA chain.

    Args:
        question_request: Request body carrying the ``question`` string.

    Returns:
        A dict ``{"answer": <text>}`` where the text is the cleaned model
        answer followed by a "Sources:" listing, or by "No sources found"
        when the chain returned no source documents.

    Raises:
        HTTPException: 400 when the question is empty or whitespace only;
            500 when the chain or post-processing fails.
    """
    # Validate outside the try block: previously the 400 raised here was
    # caught by the broad `except Exception` below and turned into a 500.
    question = question_request.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question is required")

    try:
        # NOTE(review): this call is synchronous inside an async endpoint, so
        # it occupies the event loop for the duration of the model run.
        # `invoke` replaces the deprecated `Chain.__call__` form.
        result = qa_chain.invoke({"query": question})

        # Clean only the model text.  The sources are appended afterwards so
        # that file paths and the "Sources:" label are not altered by the
        # symbol stripping in clean_answer.  A missing result counts as empty.
        cleaned_answer = clean_answer(result.get("result") or "")

        sources = result.get("source_documents")
        if sources:
            cleaned_answer += "\nSources:\n" + format_sources(sources)
        else:
            cleaned_answer += "\nNo sources found"

        # Return cleaned_answer wrapped in a dictionary
        return {"answer": cleaned_answer}
    except Exception as e:
        # logger.exception records the traceback; the client only sees a
        # generic 500 so internal details are not leaked.
        logger.exception("Error processing query: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces
    # (0.0.0.0), port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")