shahabkahn committed
Commit f6069e3 · verified · Parent(s): 318f2bb

Update app.py

Files changed (1): app.py +153 -142
app.py CHANGED
@@ -1,142 +1,153 @@
- from fastapi import FastAPI, HTTPException, Request
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from langchain.chains import RetrievalQA
- from langchain_community.llms import CTransformers
- from langchain.prompts import PromptTemplate
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import HuggingFaceEmbeddings
- import re
- import uvicorn
- import logging
-
- app = FastAPI()
-
- # CORS configuration
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Load embeddings and vector database
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
- try:
-     db = FAISS.load_local("vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True)
-     logger.info("Vector database loaded successfully!")
- except Exception as e:
-     logger.error(f"Failed to load vector database: {e}")
-     raise e
-
- # Load LLM
- try:
-     llm = CTransformers(
-         model="llama-2-7b-chat.ggmlv3.q4_0.bin",
-         model_type="llama",
-         max_new_tokens=128,
-         temperature=0.5,
-     )
-     logger.info("LLM model loaded successfully!")
- except Exception as e:
-     logger.error(f"Failed to load LLM model: {e}")
-     raise e
-
- # Define custom prompt template
- custom_prompt_template = """Use the following pieces of information to answer the user's question.
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
- Context: {context}
- Question: {question}
-
- Only return the helpful answer below and nothing else.
- Helpful answer:
- """
- qa_prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
-
- # Set up RetrievalQA chain
- qa_chain = RetrievalQA.from_chain_type(
-     llm=llm,
-     chain_type="stuff",
-     retriever=db.as_retriever(search_kwargs={"k": 2}),
-     return_source_documents=True,
-     chain_type_kwargs={"prompt": qa_prompt},
- )
-
- class QuestionRequest(BaseModel):
-     question: str
-
- class AnswerResponse(BaseModel):
-     answer: str
-
- def clean_answer(answer):
-     # Remove unnecessary characters and symbols
-     cleaned_answer = re.sub(r'[^\w\s.,-]', '', answer)
-
-     # Remove repetitive phrases by identifying repeated words or sequences
-     cleaned_answer = re.sub(r'\b(\w+)( \1\b)+', r'\1', cleaned_answer)
-
-     # Remove any trailing or leading spaces
-     cleaned_answer = cleaned_answer.strip()
-
-     # Replace multiple spaces with a single space
-     cleaned_answer = re.sub(r'\s+', ' ', cleaned_answer)
-
-     # Replace \n with newline character in markdown
-     cleaned_answer = re.sub(r'\\n', '\n', cleaned_answer)
-
-     # Check for bullet points and replace with markdown syntax
-     cleaned_answer = re.sub(r'^\s*-\s+(.*)$', r'* \1', cleaned_answer, flags=re.MULTILINE)
-
-     # Check for numbered lists and replace with markdown syntax
-     cleaned_answer = re.sub(r'^\s*\d+\.\s+(.*)$', r'1. \1', cleaned_answer, flags=re.MULTILINE)
-
-     # Check for headings and replace with markdown syntax
-     cleaned_answer = re.sub(r'^\s*(#+)\s+(.*)$', r'\1 \2', cleaned_answer, flags=re.MULTILINE)
-
-     return cleaned_answer
-
- def format_sources(sources):
-     formatted_sources = []
-     for source in sources:
-         metadata = source.metadata
-         page = metadata.get('page', 'Unknown page')
-         source_str = f"{metadata.get('source', 'Unknown source')}, page {page}"
-         formatted_sources.append(source_str)
-     return "\n".join(formatted_sources)
-
- @app.post("/query", response_model=AnswerResponse)
- async def query(question_request: QuestionRequest):
-     try:
-         question = question_request.question
-         if not question:
-             raise HTTPException(status_code=400, detail="Question is required")
-
-         result = qa_chain({"query": question})
-         answer = result.get("result")
-         sources = result.get("source_documents")
-
-         if sources:
-             formatted_sources = format_sources(sources)
-             answer += "\nSources:\n" + formatted_sources
-         else:
-             answer += "\nNo sources found"
-
-         # Clean up the answer
-         cleaned_answer = clean_answer(answer)
-
-         # Return cleaned_answer wrapped in a dictionary
-         return {"answer": cleaned_answer}
-
-     except Exception as e:
-         logger.error(f"Error processing query: {e}")
-         raise HTTPException(status_code=500, detail="Internal Server Error")
-
-
- if __name__ == '__main__':
-     uvicorn.run(app, host='0.0.0.0', port=8000)
+ import os
+ import re
+ import logging
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from langchain_community.llms import CTransformers
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ import streamlit as st
+ import uvicorn
+ from threading import Thread
+ import requests
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # FastAPI app (created before any middleware is attached, so `app` exists below)
+ app = FastAPI()
+
+ # CORS configuration
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Load embeddings and vector database
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
+ try:
+     db = FAISS.load_local("vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True)
+     logger.info("Vector database loaded successfully!")
+ except Exception as e:
+     logger.error(f"Failed to load vector database: {e}")
+     raise e
+
+ # Load LLM using ctransformers
+ try:
+     llm = CTransformers(
+         model="TheBloke/Llama-2-7B-Chat-GGML",
+         model_type="llama",
+         max_new_tokens=128,
+         temperature=0.5,
+     )
+     logger.info("LLM model loaded successfully!")
+ except Exception as e:
+     logger.error(f"Failed to load LLM model: {e}")
+     raise e
+
+ # Define custom prompt template
+ custom_prompt_template = """Use the following pieces of information to answer the user's question.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+ Context: {context}
+ Question: {question}
+
+ Only return the helpful answer below and nothing else.
+ Helpful answer:
+ """
+ qa_prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
+
+ # Set up RetrievalQA chain
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type="stuff",
+     retriever=db.as_retriever(search_kwargs={"k": 2}),
+     return_source_documents=True,
+     chain_type_kwargs={"prompt": qa_prompt},
+ )
+
+ class QuestionRequest(BaseModel):
+     question: str
+
+ class AnswerResponse(BaseModel):
+     answer: str
+
+ def clean_answer(answer):
+     # Remove unnecessary characters and symbols
+     cleaned_answer = re.sub(r'[^\w\s.,-]', '', answer)
+
+     # Remove repetitive phrases by identifying repeated words or sequences
+     cleaned_answer = re.sub(r'\b(\w+)( \1\b)+', r'\1', cleaned_answer)
+
+     # Remove any trailing or leading spaces
+     cleaned_answer = cleaned_answer.strip()
+
+     # Replace multiple spaces with a single space
+     cleaned_answer = re.sub(r'\s+', ' ', cleaned_answer)
+
+     # Replace \n with newline character in markdown
+     cleaned_answer = re.sub(r'\\n', '\n', cleaned_answer)
+
+     # Check for bullet points and replace with markdown syntax
+     cleaned_answer = re.sub(r'^\s*-\s+(.*)$', r'* \1', cleaned_answer, flags=re.MULTILINE)
+
+     # Check for numbered lists and replace with markdown syntax
+     cleaned_answer = re.sub(r'^\s*\d+\.\s+(.*)$', r'1. \1', cleaned_answer, flags=re.MULTILINE)
+
+     # Check for headings and replace with markdown syntax
+     cleaned_answer = re.sub(r'^\s*(#+)\s+(.*)$', r'\1 \2', cleaned_answer, flags=re.MULTILINE)
+
+     return cleaned_answer
+
+ def format_sources(sources):
+     formatted_sources = []
+     for source in sources:
+         metadata = source.metadata
+         page = metadata.get('page', 'Unknown page')
+         source_str = f"{metadata.get('source', 'Unknown source')}, page {page}"
+         formatted_sources.append(source_str)
+     return "\n".join(formatted_sources)
+
+ @app.post("/query", response_model=AnswerResponse)
+ async def query(question_request: QuestionRequest):
+     try:
+         question = question_request.question
+         if not question:
+             raise HTTPException(status_code=400, detail="Question is required")
+
+         result = qa_chain({"query": question})
+         answer = result.get("result")
+         sources = result.get("source_documents")
+
+         if sources:
+             formatted_sources = format_sources(sources)
+             answer += "\nSources:\n" + formatted_sources
+         else:
+             answer += "\nNo sources found"
+
+         # Clean up the answer
+         cleaned_answer = clean_answer(answer)
+
+         # Return cleaned_answer wrapped in a dictionary
+         return {"answer": cleaned_answer}
+
+     except Exception as e:
+         logger.error(f"Error processing query: {e}")
+         raise HTTPException(status_code=500, detail="Internal Server Error")
+
+
+ # if __name__ == '__main__':
+ #     uvicorn.run(app, host='0.0.0.0', port=7860)
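
For a quick smoke test of the /query endpoint after this change, a minimal client sketch: the question string is a made-up example, and the port assumes the 7860 value from the commented-out uvicorn.run line above.

import requests

# Hypothetical smoke test for the /query endpoint defined in app.py.
resp = requests.post(
    "http://localhost:7860/query",
    json={"question": "What does the document say about dosage?"},  # example question
    timeout=120,  # CPU inference with a 7B GGML model can take a while
)
resp.raise_for_status()
print(resp.json()["answer"])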
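The newly added imports (streamlit, Thread, requests) are not used anywhere in this version of the file; together with the commented-out uvicorn.run they point toward serving FastAPI in a background thread behind a Streamlit UI. A hypothetical sketch of that pattern, assuming it sits at the bottom of app.py where `app` and the imports above are in scope (the UI strings are invented, not from this commit):

def run_api():
    # Serve the FastAPI app defined above on the Space's default port.
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Daemon thread so the API dies with the Streamlit process.
Thread(target=run_api, daemon=True).start()

st.title("Document Q&A")  # hypothetical title
question = st.text_input("Ask a question")
if question:
    r = requests.post("http://localhost:7860/query", json={"question": question})
    st.markdown(r.json()["answer"])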