chaithanyashaji committed · verified
Commit 766d2b2 · 1 Parent(s): edd5cd7

Update main.py

Files changed (1): main.py +113 -105

main.py CHANGED
@@ -1,105 +1,113 @@
- import logging
- from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
- from langchain_huggingface import HuggingFaceEmbeddings
- from sentence_transformers import SentenceTransformer
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import FAISS
- from langchain.prompts import PromptTemplate
- from langchain_together import Together
- from langchain.memory import ConversationBufferMemory
- from langchain.chains import ConversationalRetrievalChain
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- import os
- from dotenv import load_dotenv
- import warnings
-
- # Logging configuration
- logging.basicConfig(level=logging.DEBUG)
- logger = logging.getLogger(__name__)
- logger.debug("Starting FastAPI app...")
-
- # Suppress warnings
- warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
- warnings.filterwarnings("ignore", message="Tried to instantiate class '__path__._path'")
- warnings.filterwarnings("ignore", category=FutureWarning)
- warnings.filterwarnings("ignore", category=DeprecationWarning)
-
- # Load environment variables
- load_dotenv()
- TOGETHER_AI_API = os.getenv("TOGETHER_AI")
-
- if not TOGETHER_AI_API:
-     raise ValueError("Environment variable TOGETHER_AI_API is missing. Please set it in your .env file.")
-
- # Initialize embeddings and vectorstore
- embeddings = HuggingFaceEmbeddings(
-     model_name="nomic-ai/nomic-embed-text-v1",
-     model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
- )
-
- # Ensure FAISS vectorstore is loaded properly
- try:
-     db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
-     db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2, "max_length": 512})
- except Exception as e:
-     logger.error(f"Error loading FAISS vectorstore: {e}")
-     raise RuntimeError("FAISS vectorstore could not be loaded. Ensure the vector database exists.")
-
- # Define the prompt template
- prompt_template = """<s>[INST]As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context. Avoid unnecessary details or unrelated content. Only respond if the answer can be derived from the provided context; otherwise, say "The information is not available in the provided context."
- CONTEXT: {context}
- CHAT HISTORY: {chat_history}
- QUESTION: {question}
- ANSWER:
- </s>[INST]
- """
- prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
-
- # Initialize the Together API
- try:
-     llm = Together(
-         model="mistralai/Mistral-7B-Instruct-v0.2",
-         temperature=0.5,
-         max_tokens=1024,
-         together_api_key=TOGETHER_AI_API,
-     )
- except Exception as e:
-     logger.error(f"Error initializing Together API: {e}")
-     raise RuntimeError("Together API could not be initialized. Check your API key and network connection.")
-
- # Initialize conversational retrieval chain
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
- qa = ConversationalRetrievalChain.from_llm(
-     llm=llm,
-     memory=memory,
-     retriever=db_retriever,
-     combine_docs_chain_kwargs={"prompt": prompt},
- )
-
- # Initialize FastAPI app
- app = FastAPI()
-
- # Define request and response models
- class ChatRequest(BaseModel):
-     question: str
-
- class ChatResponse(BaseModel):
-     answer: str
-
- # Health check endpoint
- @app.get("/")
- async def root():
-     return {"message": "Hello, World!"}
-
- # Chat endpoint
- @app.post("/chat", response_model=ChatResponse)
- async def chat(request: ChatRequest):
-     try:
-         # Pass the user question
-         result = qa.invoke(input=request.question)
-         answer = result.get("answer", "The chatbot could not generate a response.")
-         return ChatResponse(answer=answer)
-     except Exception as e:
-         logger.error(f"Error during chat invocation: {e}")
-         raise HTTPException(status_code=500, detail="Internal server error")
+ import logging
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from sentence_transformers import SentenceTransformer
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain.prompts import PromptTemplate
+ from langchain_together import Together
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import os
+ from dotenv import load_dotenv
+ import warnings
+
+ # Logging configuration
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+ logger.debug("Starting FastAPI app...")
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
+ warnings.filterwarnings("ignore", message="Tried to instantiate class '__path__._path'")
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+ # Load environment variables
+ load_dotenv()
+
+ # Set HF_HOME for the Hugging Face cache directory
+ HF_HOME = os.getenv("HF_HOME", "/tmp/cache")
+ os.environ["HF_HOME"] = HF_HOME
+
+ # Ensure the cache directory exists
+ os.makedirs(HF_HOME, exist_ok=True)
+
+ # Validate the Together API environment variable
+ TOGETHER_AI_API = os.getenv("TOGETHER_AI")
+ if not TOGETHER_AI_API:
+     raise ValueError("Environment variable TOGETHER_AI is missing. Please set it in your .env file.")
+
+ # Initialize embeddings and vectorstore
+ embeddings = HuggingFaceEmbeddings(
+     model_name="nomic-ai/nomic-embed-text-v1",
+     model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
+ )
+
+ # Ensure FAISS vectorstore is loaded properly
+ try:
+     db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
+     db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2, "max_length": 512})
+ except Exception as e:
+     logger.error(f"Error loading FAISS vectorstore: {e}")
+     raise RuntimeError("FAISS vectorstore could not be loaded. Ensure the vector database exists.")
+
+ # Define the prompt template
+ prompt_template = """<s>[INST]As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context. Avoid unnecessary details or unrelated content. Only respond if the answer can be derived from the provided context; otherwise, say "The information is not available in the provided context."
+ CONTEXT: {context}
+ CHAT HISTORY: {chat_history}
+ QUESTION: {question}
+ [/INST]
+ ANSWER:
+ """
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
+
+ # Initialize the Together API
+ try:
+     llm = Together(
+         model="mistralai/Mistral-7B-Instruct-v0.2",
+         temperature=0.5,
+         max_tokens=1024,
+         together_api_key=TOGETHER_AI_API,
+     )
+ except Exception as e:
+     logger.error(f"Error initializing Together API: {e}")
+     raise RuntimeError("Together API could not be initialized. Check your API key and network connection.")
+
+ # Initialize conversational retrieval chain
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+ qa = ConversationalRetrievalChain.from_llm(
+     llm=llm,
+     memory=memory,
+     retriever=db_retriever,
+     combine_docs_chain_kwargs={"prompt": prompt},
+ )
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Define request and response models
+ class ChatRequest(BaseModel):
+     question: str
+
+ class ChatResponse(BaseModel):
+     answer: str
+
+ # Health check endpoint
+ @app.get("/")
+ async def root():
+     return {"message": "Hello, World!"}
+
+ # Chat endpoint
+ @app.post("/chat", response_model=ChatResponse)
+ async def chat(request: ChatRequest):
+     try:
+         # Pass the user question through the retrieval chain
+         result = qa.invoke(input=request.question)
+         answer = result.get("answer", "The chatbot could not generate a response.")
+         return ChatResponse(answer=answer)
+     except Exception as e:
+         logger.error(f"Error during chat invocation: {e}")
+         raise HTTPException(status_code=500, detail="Internal server error")
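
One caveat on the new HF_HOME block, the main change in this commit: `huggingface_hub` computes its cache paths from `HF_HOME` when the library is first imported, so assigning `os.environ["HF_HOME"]` after the `langchain_huggingface` import at the top of the file may not actually redirect the cache. If the variable cannot be exported in the deployment environment (for example, in the Space settings), a minimal sketch of a safer ordering, assuming the goal is to redirect the model cache to /tmp/cache:

```python
# Sketch: set the cache location before anything imports huggingface_hub,
# so the override is visible when the library computes its cache paths.
import os
os.environ.setdefault("HF_HOME", "/tmp/cache")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

# Only import libraries that read HF_HOME after the override is in place.
from langchain_huggingface import HuggingFaceEmbeddings
```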
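
Separately, the script loads a prebuilt index from `ipc_vector_db` but never creates one, and the otherwise-unused `PyPDFLoader`, `DirectoryLoader`, and `RecursiveCharacterTextSplitter` imports suggest the index was built in an offline ingestion step. A minimal sketch of what that step might look like; the `data/` source directory and the chunking parameters are assumptions, not taken from this commit:

```python
# build_index.py: hypothetical offline ingestion step. The "data/" directory
# and the chunk sizes below are illustrative assumptions.
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

# Use the same embedding model and revision as main.py so query vectors
# and index vectors live in the same space.
embeddings = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
)

# Load every PDF under data/ and split it into overlapping chunks.
docs = DirectoryLoader("data/", glob="**/*.pdf", loader_cls=PyPDFLoader).load()
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)

# Embed the chunks and persist the index where main.py expects to find it.
FAISS.from_documents(chunks, embeddings).save_local("ipc_vector_db")
```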
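
Finally, to smoke-test the service once it is running (for example with `uvicorn main:app --port 8000`, and with `TOGETHER_AI=<your key>` available via `.env` or the environment), a minimal client might look like this; the host and port are uvicorn defaults, not something the commit specifies:

```python
# client.py: minimal smoke test for the two endpoints. The base URL assumes
# a local uvicorn server on port 8000 (an assumption, not set by this commit).
import requests

BASE = "http://127.0.0.1:8000"

# Health check should return {"message": "Hello, World!"}
print(requests.get(f"{BASE}/").json())

# The /chat body must match the ChatRequest model: {"question": "..."}
resp = requests.post(
    f"{BASE}/chat",
    json={"question": "What is Section 302 of the IPC?"},
    timeout=120,  # generous timeout: the call blocks on retrieval + LLM
)
resp.raise_for_status()
print(resp.json()["answer"])
```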