import re
import logging
import subprocess

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from dotenv import load_dotenv
import uvicorn

# Load environment variables
load_dotenv()


# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# FastAPI app
app = FastAPI()

# CORS configuration: middleware must be registered after the app object exists
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Load embeddings and vector database
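# The FAISS index under vectorstore/db_faiss is assumed to have been built by a
# separate ingestion step; load_local only reads an existing index from disk.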
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
try:
    db = FAISS.load_local("vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True)
    logger.info("Vector database loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load vector database: {e}")
    raise

# Load LLM using ctransformers
try:
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        max_new_tokens=128,
        temperature=0.5,
    )
    logger.info("LLM model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load LLM model: {e}")
    raise

# Define custom prompt template
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
qa_prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])

# Set up RetrievalQA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": qa_prompt},
)
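
# Quick sanity check of the chain from a REPL (a sketch; the question text is
# hypothetical, and the result keys follow RetrievalQA's defaults when
# return_source_documents=True):
#
#   out = qa_chain.invoke({"query": "What topics does chapter 1 cover?"})
#   print(out["result"])
#   for doc in out["source_documents"]:
#       print(doc.metadata)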

class QuestionRequest(BaseModel):
    question: str

class AnswerResponse(BaseModel):
    answer: str
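
# JSON shapes implied by the models above:
#   request:  {"question": "..."}
#   response: {"answer": "..."}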

def clean_answer(answer):
    # Convert escaped "\n" sequences into real newlines first, so the
    # line-based rules below can see the line structure (doing this after the
    # symbol filter would strip the backslashes and the rule could never fire)
    cleaned_answer = re.sub(r'\\n', '\n', answer)

    # Remove stray symbols, but keep characters that carry markdown meaning
    # (#, *, -) alongside word characters, whitespace, and basic punctuation
    cleaned_answer = re.sub(r'[^\w\s.,#*-]', '', cleaned_answer)

    # Collapse immediately repeated words ("the the" -> "the")
    cleaned_answer = re.sub(r'\b(\w+)( \1\b)+', r'\1', cleaned_answer)

    # Collapse runs of spaces and tabs without touching newlines
    cleaned_answer = re.sub(r'[ \t]+', ' ', cleaned_answer)

    # Normalize bullet points to markdown "*" syntax
    cleaned_answer = re.sub(r'^\s*-\s+(.*)$', r'* \1', cleaned_answer, flags=re.MULTILINE)

    # Normalize numbered list items (markdown renderers renumber "1." items)
    cleaned_answer = re.sub(r'^\s*\d+\.\s+(.*)$', r'1. \1', cleaned_answer, flags=re.MULTILINE)

    # Normalize heading markers ("##   Title" -> "## Title")
    cleaned_answer = re.sub(r'^\s*(#+)\s+(.*)$', r'\1 \2', cleaned_answer, flags=re.MULTILINE)

    # Trim leading and trailing whitespace
    return cleaned_answer.strip()
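
# Illustrative behaviour (the input text is hypothetical):
#   clean_answer("The answer answer is:\\n- item one\\n- item two")
#   -> "The answer is\n* item one\n* item two"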

def format_sources(sources):
    formatted_sources = []
    for source in sources:
        metadata = source.metadata
        page = metadata.get('page', 'Unknown page')
        source_str = f"{metadata.get('source', 'Unknown source')}, page {page}"
        formatted_sources.append(source_str)
    return "\n".join(formatted_sources)

@app.post("/query", response_model=AnswerResponse)
async def query(question_request: QuestionRequest):
    try:
        question = question_request.question
        if not question.strip():
            raise HTTPException(status_code=400, detail="Question is required")

        # .invoke is the current, non-deprecated way to call a LangChain chain
        result = qa_chain.invoke({"query": question})
        answer = result.get("result")
        sources = result.get("source_documents")

        if sources:
            formatted_sources = format_sources(sources)
            answer += "\nSources:\n" + formatted_sources
        else:
            answer += "\nNo sources found"

        # Clean up the answer
        cleaned_answer = clean_answer(answer)

        # Return cleaned_answer wrapped in a dictionary
        return {"answer": cleaned_answer}

    except HTTPException:
        # Let explicit client errors (e.g. the 400 above) pass through
        # instead of being masked as 500s by the generic handler below
        raise
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
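
# Example request against this endpoint (host/port depend on how uvicorn is
# started; 7860 matches the __main__ block at the bottom of this file):
#
#   import requests
#   resp = requests.post("http://localhost:7860/query",
#                        json={"question": "What does the document cover?"})
#   print(resp.json()["answer"])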


# Launch the Streamlit frontend and stream its console output. Note that this
# spawns a new Streamlit process on each request to "/"; running the frontend
# as a separate service is the more robust setup.
@app.get("/")
async def serve_streamlit():
    streamlit_process = subprocess.Popen(
        ["streamlit", "run", "frontend.py", "--server.port", "8501"],
        stdout=subprocess.PIPE,
    )
    return StreamingResponse(streamlit_process.stdout, media_type="text/plain")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)