shahabkahn commited on
Commit
fe534eb
·
verified ·
1 Parent(s): 5fb764b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -24
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import re
3
  import logging
 
4
  from fastapi import FastAPI, HTTPException
5
  from fastapi.responses import RedirectResponse
6
  from pydantic import BaseModel
@@ -9,21 +10,18 @@ from langchain.prompts import PromptTemplate
9
  from langchain_community.llms import CTransformers
10
  from langchain_community.vectorstores import FAISS
11
  from langchain_community.embeddings import HuggingFaceEmbeddings
12
- import subprocess
13
  from dotenv import load_dotenv
14
 
15
  # Load environment variables
16
  load_dotenv()
17
 
18
- # Set up logging
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
- # FastAPI app
23
  app = FastAPI()
24
 
25
- # Load embeddings and vector database
26
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
 
27
  try:
28
  db = FAISS.load_local("vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True)
29
  logger.info("Vector database loaded successfully!")
@@ -31,7 +29,6 @@ except Exception as e:
31
  logger.error(f"Failed to load vector database: {e}")
32
  raise e
33
 
34
- # Load LLM using ctransformers
35
  try:
36
  llm = CTransformers(
37
  model="TheBloke/Llama-2-7B-Chat-GGML",
@@ -44,7 +41,6 @@ except Exception as e:
44
  logger.error(f"Failed to load LLM model: {e}")
45
  raise e
46
 
47
- # Define custom prompt template
48
  custom_prompt_template = """Use the following pieces of information to answer the user's question.
49
  If you don't know the answer, just say that you don't know, don't try to make up an answer.
50
 
@@ -56,7 +52,6 @@ Helpful answer:
56
  """
57
  qa_prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
58
 
59
- # Set up RetrievalQA chain
60
  qa_chain = RetrievalQA.from_chain_type(
61
  llm=llm,
62
  chain_type="stuff",
@@ -72,23 +67,14 @@ class AnswerResponse(BaseModel):
72
  answer: str
73
 
74
def clean_answer(answer):
    """Tidy a raw LLM answer: strip symbols, dedupe repeated words,
    squeeze whitespace, and normalize list/heading markup to markdown."""
    # Drop every character that is not word/space/.,- .
    # NOTE(review): this also removes '\' and '#', so the r'\\n' and
    # heading substitutions below can never match — preserved as-is.
    text = re.sub(r'[^\w\s.,-]', '', answer)
    # Collapse immediately repeated words ("the the" -> "the").
    text = re.sub(r'\b(\w+)( \1\b)+', r'\1', text)
    # Trim the ends, then squeeze all whitespace runs to single spaces.
    text = re.sub(r'\s+', ' ', text.strip())
    # Turn literal backslash-n sequences into real newlines
    # (dead code in practice: backslashes were stripped above).
    text = re.sub(r'\\n', '\n', text)
    # Normalize bullets, numbered lists, and headings to markdown syntax.
    markdown_rules = (
        (r'^\s*-\s+(.*)$', r'* \1'),
        (r'^\s*\d+\.\s+(.*)$', r'1. \1'),
        (r'^\s*(#+)\s+(.*)$', r'\1 \2'),
    )
    for pattern, replacement in markdown_rules:
        text = re.sub(pattern, replacement, text, flags=re.MULTILINE)
    return text
93
 
94
  def format_sources(sources):
@@ -107,7 +93,8 @@ async def query(question_request: QuestionRequest):
107
  if not question:
108
  raise HTTPException(status_code=400, detail="Question is required")
109
 
110
- result = qa_chain({"query": question})
 
111
  answer = result.get("result")
112
  sources = result.get("source_documents")
113
 
@@ -117,19 +104,16 @@ async def query(question_request: QuestionRequest):
117
  else:
118
  answer += "\nNo sources found"
119
 
120
- # Clean up the answer
121
  cleaned_answer = clean_answer(answer)
122
-
123
  return {"answer": cleaned_answer}
124
-
125
  except Exception as e:
126
  logger.error(f"Error processing query: {e}")
127
  raise HTTPException(status_code=500, detail="Internal Server Error")
128
 
 
 
 
129
 
130
  @app.get("/")
131
  async def root():
132
- return RedirectResponse(url="/docs")
133
-
134
- #if __name__ == '__main__':
135
- #uvicorn.run(app, host='0.0.0.0', port=7860)
 
1
  import os
2
  import re
3
  import logging
4
+ import asyncio
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.responses import RedirectResponse
7
  from pydantic import BaseModel
 
10
  from langchain_community.llms import CTransformers
11
  from langchain_community.vectorstores import FAISS
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
13
  from dotenv import load_dotenv
14
 
15
  # Load environment variables
16
  load_dotenv()
17
 
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
 
21
  app = FastAPI()
22
 
 
23
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
24
+
25
  try:
26
  db = FAISS.load_local("vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True)
27
  logger.info("Vector database loaded successfully!")
 
29
  logger.error(f"Failed to load vector database: {e}")
30
  raise e
31
 
 
32
  try:
33
  llm = CTransformers(
34
  model="TheBloke/Llama-2-7B-Chat-GGML",
 
41
  logger.error(f"Failed to load LLM model: {e}")
42
  raise e
43
 
 
44
  custom_prompt_template = """Use the following pieces of information to answer the user's question.
45
  If you don't know the answer, just say that you don't know, don't try to make up an answer.
46
 
 
52
  """
53
  qa_prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
54
 
 
55
  qa_chain = RetrievalQA.from_chain_type(
56
  llm=llm,
57
  chain_type="stuff",
 
67
  answer: str
68
 
69
def clean_answer(answer):
    """Normalize a raw LLM answer into tidy markdown text.

    Strips stray symbols, collapses immediately repeated words, squeezes
    horizontal whitespace, and normalizes bullet/numbered/heading lines
    to markdown syntax.
    """
    # Convert literal "\n" escape sequences to real newlines FIRST.
    # Previously the symbol filter below ran first and stripped the
    # backslash, making this substitution dead code and degenerating
    # every "\n" into a stray 'n' glued to the text.
    cleaned = re.sub(r'\\n', '\n', answer)
    # Drop unwanted symbols; keep '#' so the heading rule below can
    # still match (it was previously stripped here, dead-coding it).
    cleaned = re.sub(r'[^\w\s.,#-]', '', cleaned)
    # Collapse immediately repeated words ("the the" -> "the").
    cleaned = re.sub(r'\b(\w+)( \1\b)+', r'\1', cleaned)
    cleaned = cleaned.strip()
    # Collapse horizontal whitespace only; real newlines must survive
    # for the MULTILINE rules below (r'\s+' previously flattened them,
    # so those rules could only ever see a single line).
    cleaned = re.sub(r'[ \t]+', ' ', cleaned)
    # Normalize bullets, numbered lists, and headings to markdown.
    cleaned = re.sub(r'^\s*-\s+(.*)$', r'* \1', cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r'^\s*\d+\.\s+(.*)$', r'1. \1', cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r'^\s*(#+)\s+(.*)$', r'\1 \2', cleaned, flags=re.MULTILINE)
    return cleaned
79
 
80
  def format_sources(sources):
 
93
  if not question:
94
  raise HTTPException(status_code=400, detail="Question is required")
95
 
96
+ loop = asyncio.get_event_loop()
97
+ result = await loop.run_in_executor(None, qa_chain, {"query": question})
98
  answer = result.get("result")
99
  sources = result.get("source_documents")
100
 
 
104
  else:
105
  answer += "\nNo sources found"
106
 
 
107
  cleaned_answer = clean_answer(answer)
 
108
  return {"answer": cleaned_answer}
 
109
  except Exception as e:
110
  logger.error(f"Error processing query: {e}")
111
  raise HTTPException(status_code=500, detail="Internal Server Error")
112
 
113
@app.on_event("startup")
async def startup_event():
    """Launch the Streamlit frontend as a sidecar process when the API starts."""
    # This change removed `import subprocess` from the top-level imports
    # while adding this call, which would raise NameError on startup —
    # import it locally so the handler is self-contained.
    import subprocess
    # Fire-and-forget: the frontend process is not tracked or terminated
    # by this app (no handle kept, no shutdown hook).
    subprocess.Popen(["streamlit", "run", "frontend.py", "--server.port", "8501"])
116
 
117
  @app.get("/")
118
  async def root():
119
+ return RedirectResponse(url="http://localhost:8501")