Update main.py

main.py CHANGED
@@ -1,4 +1,10 @@
 import logging
+import os
+import warnings
+from dotenv import load_dotenv
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+import uvicorn
 from langchain_community.document_loaders import DirectoryLoader
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -7,59 +13,73 @@ from langchain.prompts import PromptTemplate
 from langchain_together import Together
 from langchain.memory import ConversationBufferMemory
 from langchain.chains import ConversationalRetrievalChain
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-import os
-from dotenv import load_dotenv
-import warnings
-import uvicorn
 
-#
+# ==========================
+# Logging Configuration
+# ==========================
 logging.basicConfig(level=logging.DEBUG)
-logger = logging.getLogger(
-logger.debug("
+logger = logging.getLogger("LegalChatbot")
+logger.debug("Initializing Legal Chatbot application...")
 
-#
+# ==========================
+# Suppress Warnings
+# ==========================
 warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False")
 
-#
+# ==========================
+# Load Environment Variables
+# ==========================
 load_dotenv()
 TOGETHER_AI_API = os.getenv("TOGETHER_AI")
 HF_HOME = os.getenv("HF_HOME", "./cache")
 os.environ["HF_HOME"] = HF_HOME
-if not os.path.exists(HF_HOME):
-    os.makedirs(HF_HOME, exist_ok=True)
 
+# Ensure the HF_HOME directory exists
+os.makedirs(HF_HOME, exist_ok=True)
+
+# Validate required environment variables
 if not TOGETHER_AI_API:
-    raise ValueError("TOGETHER_AI_API environment variable is missing.")
+    raise ValueError("The TOGETHER_AI_API environment variable is missing. Please set it in your .env file.")
 
-#
+# ==========================
+# Initialize Embeddings
+# ==========================
 try:
     embeddings = HuggingFaceEmbeddings(
         model_name="nomic-ai/nomic-embed-text-v1",
-        model_kwargs={"trust_remote_code": True,"revision":"289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
+        model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
     )
+    logger.info("Embeddings successfully initialized.")
 except Exception as e:
-    logger.error(f"Error
-    raise RuntimeError("
-
-# Load FAISS vectorstore
-
-db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
-db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2, "max-length":512})
+    logger.error(f"Error initializing embeddings: {e}")
+    raise RuntimeError("Oops! Something went wrong while setting up embeddings. Please check the configuration and try again.")
 
+# ==========================
+# Load FAISS Vectorstore
+# ==========================
+try:
+    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
+    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2, "max-length": 512})
+    logger.info("Vectorstore successfully loaded.")
+except Exception as e:
+    logger.error(f"Error loading FAISS vectorstore: {e}")
+    raise RuntimeError("We couldn't load the vector database. Please ensure the database file is available and try again.")
 
-#
-
+# ==========================
+# Define Prompt Template
+# ==========================
+prompt_template = """<s>[INST]As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context. Avoid unnecessary details or unrelated content. Only respond if the answer can be derived from the provided context; otherwise, say \"The information is not available in the provided context.\"
 CONTEXT: {context}
 CHAT HISTORY: {chat_history}
 QUESTION: {question}
 ANSWER:
-</s>[INST]
-
+</s>[INST]"""
+
 prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
 
+# ==========================
 # Initialize Together API
+# ==========================
 try:
     llm = Together(
         model="mistralai/Mistral-7B-Instruct-v0.2",
@@ -67,11 +87,14 @@ try:
         max_tokens=1024,
         together_api_key=TOGETHER_AI_API,
     )
+    logger.info("Together API successfully initialized.")
 except Exception as e:
     logger.error(f"Error initializing Together API: {e}")
-    raise RuntimeError("
+    raise RuntimeError("Something went wrong with the Together API setup. Please verify your API key and configuration.")
 
-#
+# ==========================
+# Conversational Retrieval Chain
+# ==========================
 memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 qa = ConversationalRetrievalChain.from_llm(
     llm=llm,
@@ -79,8 +102,11 @@ qa = ConversationalRetrievalChain.from_llm(
     retriever=db_retriever,
     combine_docs_chain_kwargs={"prompt": prompt},
 )
+logger.info("Conversational Retrieval Chain initialized.")
 
-#
+# ==========================
+# FastAPI Backend
+# ==========================
 app = FastAPI()
 
 class ChatRequest(BaseModel):
@@ -91,19 +117,24 @@ class ChatResponse(BaseModel):
 
 @app.get("/")
 async def root():
-    return {"message": "Legal Chatbot
+    return {"message": "Hello! Welcome to the Legal Chatbot. I'm here to assist you with your legal queries related to the Indian Penal Code. How can I help you today?"}
 
 @app.post("/chat", response_model=ChatResponse)
 async def chat(request: ChatRequest):
     try:
-        logger.debug(f"
+        logger.debug(f"Received user question: {request.question}")
         result = qa.invoke(input=request.question)
-        answer = result.get("answer", "
+        answer = result.get("answer", "I'm sorry, but I couldn't generate a response to your query. Please try rephrasing or providing more details.")
        return ChatResponse(answer=answer)
     except Exception as e:
         logger.error(f"Error during chat invocation: {e}")
-        raise HTTPException(status_code=500, detail="
+        raise HTTPException(status_code=500, detail="Oops! Something went wrong on our end. Please try again later.")
 
-#
+# ==========================
+# Run Uvicorn Server
+# ==========================
 if __name__ == "__main__":
     uvicorn.run("main:app", host="0.0.0.0", port=7860)
+
+
+
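For a quick check of the updated endpoints, a minimal client sketch like the one below can exercise the Space once it is running. The host and port come from the uvicorn.run call above, and the question/answer field names come from the ChatRequest and ChatResponse models; the sample question itself is invented.

import requests

BASE_URL = "http://localhost:7860"  # host/port as configured in uvicorn.run

# The root endpoint returns the welcome message.
print(requests.get(f"{BASE_URL}/").json()["message"])

# /chat takes a JSON body matching ChatRequest and returns a ChatResponse.
resp = requests.post(
    f"{BASE_URL}/chat",
    json={"question": "What does the Indian Penal Code say about theft?"},  # invented sample
)
resp.raise_for_status()
print(resp.json()["answer"])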
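As a reading aid for the new prompt wiring, the PromptTemplate section of the diff can be reproduced standalone to inspect exactly what string the Mistral model receives. This is only a sketch: the instruction text is abridged from main.py, and the values passed to format() are dummy placeholders.

from langchain.prompts import PromptTemplate

# Abridged version of the prompt_template defined in main.py.
prompt_template = """<s>[INST]As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
</s>[INST]"""

prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question", "chat_history"],
)

# Render with placeholder values to see the final LLM input.
print(prompt.format(
    context="<retrieved IPC passages>",
    chat_history="<previous turns>",
    question="<user question>",
))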