import logging
import os
import warnings

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
import uvicorn
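# Assumed third-party dependencies for this app (the usual pip package names for
# the imports above; pin versions as appropriate for your environment):
#   fastapi, uvicorn, pydantic, python-dotenv, langchain, langchain-community,
#   langchain-huggingface, langchain-together, faiss-cpu, sentence-transformers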
# ==========================
# Logging Setup
# ==========================
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ==========================
# Suppress Warnings
# ==========================
warnings.filterwarnings("ignore")
# ==========================
# Load Environment Variables
# ==========================
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
HF_HOME = os.getenv("HF_HOME", "./cache")
os.environ["HF_HOME"] = HF_HOME
os.makedirs(HF_HOME, exist_ok=True)

if not TOGETHER_AI_API:
    logger.error("TOGETHER_AI key is missing. Please set it in the environment variables.")
    raise RuntimeError("API key not found. Set TOGETHER_AI in .env.")
# ==========================
# App Initialization
# ==========================
app = FastAPI()
# ==========================
# Load Existing IPC Vectorstore
# ==========================
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
    )
    logger.info("Embeddings successfully initialized.")

    # Load the pre-existing IPC vector store directly
    logger.info("Loading existing IPC vectorstore.")
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    logger.info("IPC Vectorstore successfully loaded.")
except Exception as e:
    logger.error(f"Error during vectorstore setup: {e}")
    raise RuntimeError("Initialization failed. Please check your embeddings or vectorstore setup.")
# ==========================
# Prompt Template (Context-Only)
# ==========================
prompt_template = """<s>[INST]
You are a legal assistant specializing in the Indian Penal Code (IPC). Use only the provided CONTEXT to answer questions.
If the information is not found in the CONTEXT, respond with: "I don't have enough information yet."
Do not use any outside knowledge.
CONTEXT: {context}
USER QUERY: {question}
RESPONSE:
[/INST]
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# ==========================
# Initialize Together API
# ==========================
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
    logger.info("Together API successfully initialized.")
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Something went wrong with the Together API setup. Please verify your API key.")
# ==========================
# Chat Processing Function
# ==========================
def generate_response(user_query: str) -> str:
    try:
        # Retrieve relevant documents
        retrieved_docs = db_retriever.get_relevant_documents(user_query)

        # Log retrieved documents
        logger.info(f"User Query: {user_query}")
        for i, doc in enumerate(retrieved_docs):
            logger.info(f"Document {i + 1}: {doc.page_content[:500]}...")

        # Prepare context for the LLM
        context = "\n\n".join(doc.page_content for doc in retrieved_docs)

        # Check if context is empty
        if not context.strip():
            return "I don't have enough information yet."

        # Construct LLM prompt input
        prompt_input = {"context": context, "question": user_query}
        logger.debug(f"Payload sent to LLM: {prompt_input}")

        # Generate response using the LLM
        response = llm(prompt.format(**prompt_input))

        # Check if response is empty
        if not response.strip():
            return "I don't have enough information yet."
        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "An error occurred while generating the response."
# ==========================
# FastAPI Models and Endpoints
# ==========================
class ChatRequest(BaseModel):
    question: str


class ChatResponse(BaseModel):
    answer: str


# Route paths "/" and "/chat" are assumed for these endpoints.
@app.get("/")
async def root():
    return {
        "message": "Welcome to the Legal Chatbot! Ask me questions about the Indian Penal Code (IPC)."
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        logger.debug(f"User question received: {request.question}")
        answer = generate_response(request.question)
        logger.debug(f"Chatbot response: {answer}")
        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error processing chat request: {e}")
        raise HTTPException(status_code=500, detail="An internal error occurred. Please try again later.")
# ==========================
# Run Uvicorn Server
# ==========================
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
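# Example request once the server is running (the "/chat" path matches the
# route assumed above; adjust if your endpoint differs):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the IPC say about theft?"}'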