chaithanyashaji committed on
Commit 968aa34 · verified · 1 Parent(s): c6172d3

Update main.py

Files changed (1): main.py (+22 −52)
main.py CHANGED
```diff
@@ -17,28 +17,21 @@ import uvicorn
 # Logging configuration
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
-logger.debug("Starting FastAPI app...")
+logger.debug("Starting application...")
 
 # Suppress warnings
-warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
-warnings.filterwarnings("ignore", category=FutureWarning)
-warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
 
 # Load environment variables
 load_dotenv()
 TOGETHER_AI_API = os.getenv("TOGETHER_AI")
 HF_HOME = os.getenv("HF_HOME", "./cache")
-
-# Set cache directory for Hugging Face
 os.environ["HF_HOME"] = HF_HOME
-
-# Ensure HF_HOME exists and is writable
 if not os.path.exists(HF_HOME):
     os.makedirs(HF_HOME, exist_ok=True)
 
-# Validate environment variables
 if not TOGETHER_AI_API:
-    raise ValueError("Environment variable TOGETHER_AI_API is missing. Please set it in your .env file.")
+    raise ValueError("TOGETHER_AI_API environment variable is missing.")
 
 # Initialize embeddings
 try:
```
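A note on the cache setup kept by this hunk: recent versions of huggingface_hub resolve HF_HOME when the library is first imported, so assigning `os.environ["HF_HOME"]` after the model libraries have already been imported may have no effect on where weights are cached. A minimal sketch of the safer ordering (the `HuggingFaceEmbeddings` import path is an assumption here, since the import block sits above this diff):

```python
# Sketch: set the cache location before importing anything that reads it.
# huggingface_hub resolves HF_HOME at import time, so this assignment must
# run first for "./cache" to actually be used.
import os

os.environ.setdefault("HF_HOME", "./cache")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

# Import libraries that consult HF_HOME only after the variable is set.
from langchain_community.embeddings import HuggingFaceEmbeddings  # noqa: E402
```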
```diff
@@ -48,54 +41,43 @@ try:
     )
 except Exception as e:
     logger.error(f"Error loading embeddings: {e}")
-    raise RuntimeError("Error initializing HuggingFaceEmbeddings.")
+    raise RuntimeError("Failed to initialize embeddings.")
 
-# Ensure FAISS vectorstore is loaded or created
+# Load FAISS vectorstore
 try:
     db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
     db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5, "score_threshold": 0.8})
 except Exception as e:
     logger.error(f"Error loading FAISS vectorstore: {e}")
-    # If not found, create a new vectorstore
-    try:
-        loader = DirectoryLoader('./data')
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
-        documents = text_splitter.split_documents(loader.load())
-        db = FAISS.from_documents(documents, embeddings)
-        db.save_local("ipc_vector_db")
-        db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5, "score_threshold": 0.8})
-    except Exception as inner_e:
-        logger.error(f"Error creating FAISS vectorstore: {inner_e}")
-        raise RuntimeError("FAISS vectorstore could not be created or loaded.")
-
-# Define the prompt template
+    loader = DirectoryLoader('./data')
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+    documents = text_splitter.split_documents(loader.load())
+    db = FAISS.from_documents(documents, embeddings)
+    db.save_local("ipc_vector_db")
+    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5, "score_threshold": 0.8})
+
+# Define prompt template
 prompt_template = """
-As a legal chatbot specializing in the Indian Penal Code (IPC), provide precise, fact-based answers to the user’s question based on the provided context.
-Respond only if the answer can be derived from the given context; otherwise, say:
-"The information is not available in the provided context."
-Use plain, professional language in your response.
+As a legal chatbot specializing in the Indian Penal Code (IPC), provide accurate and concise answers based on the context. Respond only if the answer can be derived from the given context; otherwise, reply: "The information is not available in the provided context." Use professional language.
 
 CONTEXT: {context}
-
 CHAT HISTORY: {chat_history}
-
 QUESTION: {question}
-
 ANSWER:
 """
 prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
 
-# Initialize the Together API
+# Initialize Together API
 try:
     llm = Together(
         model="mistralai/Mistral-7B-Instruct-v0.2",
-        temperature=0.3,  # Lower temperature ensures deterministic answers
-        max_tokens=512,  # Shorter response for focus
+        temperature=0.3,
+        max_tokens=512,
         together_api_key=TOGETHER_AI_API,
     )
 except Exception as e:
     logger.error(f"Error initializing Together API: {e}")
-    raise RuntimeError("Together API could not be initialized. Check your API key and network connection.")
+    raise RuntimeError("Failed to initialize Together API.")
 
 # Initialize conversational retrieval chain
 memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
```
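The rewritten fallback drops the inner try/except, so a failed rebuild now surfaces as an unhandled exception at startup rather than a logged RuntimeError. The load-or-rebuild pattern itself could also be factored into a helper; a sketch using the same paths and chunking parameters as the commit (import paths are assumed, since the import block sits above this diff):

```python
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter

def get_vectorstore(embeddings, index_path="ipc_vector_db", data_dir="./data"):
    """Load a saved FAISS index, rebuilding it from the data directory if loading fails."""
    try:
        return FAISS.load_local(index_path, embeddings,
                                allow_dangerous_deserialization=True)
    except Exception:
        # Rebuild: split the raw documents and index them from scratch.
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        documents = splitter.split_documents(DirectoryLoader(data_dir).load())
        db = FAISS.from_documents(documents, embeddings)
        db.save_local(index_path)
        return db
```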
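One behavior worth flagging in the chain wiring that follows: `memory` is a single module-level ConversationBufferMemory, so every client of the API shares one chat history. If per-user history is wanted, a session-keyed store is one option; a hypothetical sketch (session handling is not part of this commit, and `memory_for` is an invented helper):

```python
from langchain.memory import ConversationBufferMemory

# Hypothetical: one buffer per session id instead of a single shared one.
_memories: dict[str, ConversationBufferMemory] = {}

def memory_for(session_id: str) -> ConversationBufferMemory:
    # Create a fresh buffer the first time a session id is seen.
    if session_id not in _memories:
        _memories[session_id] = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        )
    return _memories[session_id]
```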
```diff
@@ -106,42 +88,30 @@ qa = ConversationalRetrievalChain.from_llm(
     combine_docs_chain_kwargs={"prompt": prompt},
 )
 
-# Initialize FastAPI app
+# FastAPI backend
 app = FastAPI()
 
-# Define request and response models
 class ChatRequest(BaseModel):
     question: str
 
 class ChatResponse(BaseModel):
     answer: str
 
-# Health check endpoint
 @app.get("/")
 async def root():
-    return {"message": "Hello, World!"}
+    return {"message": "Legal Chatbot is running."}
 
-# Chat endpoint
 @app.post("/chat", response_model=ChatResponse)
 async def chat(request: ChatRequest):
     try:
-        logger.debug(f"User Question: {request.question}")
+        logger.debug(f"User question: {request.question}")
         result = qa.invoke(input=request.question)
-        logger.debug(f"Retrieved Context: {result.get('context', '')}")
-        logger.debug(f"Model Response: {result.get('answer', '')}")
-
         answer = result.get("answer", "The chatbot could not generate a response.")
-        confidence_score = result.get("score", 0)  # Assuming LLM provides a score
-
-        if confidence_score < 0.7:
-            answer = "The answer is uncertain. Please consult a professional."
-
         return ChatResponse(answer=answer)
     except Exception as e:
         logger.error(f"Error during chat invocation: {e}")
         raise HTTPException(status_code=500, detail="Internal server error")
 
-# Start Uvicorn server if run directly
+# Start Uvicorn if run directly
 if __name__ == "__main__":
     uvicorn.run("main:app", host="0.0.0.0", port=7860)
-
```
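Finally, a usage sketch for the trimmed endpoints, assuming the server is running locally on port 7860 (the sample question is illustrative only):

```python
import requests

BASE = "http://localhost:7860"

# Health check: the new root handler returns a status message.
print(requests.get(f"{BASE}/").json())  # {"message": "Legal Chatbot is running."}

# Chat: request/response shapes follow ChatRequest and ChatResponse.
resp = requests.post(f"{BASE}/chat", json={"question": "Which IPC section deals with theft?"})
resp.raise_for_status()
print(resp.json()["answer"])
```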
 