import logging
from langchain_community.document_loaders import DirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from dotenv import load_dotenv
import warnings
import uvicorn

# Logging configuration
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.debug("Starting FastAPI app...")

# Suppress warnings
warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Load environment variables
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
HF_HOME = os.getenv("HF_HOME", "./cache")

# Set cache directory for Hugging Face
os.environ["HF_HOME"] = HF_HOME

# Ensure HF_HOME exists and is writable (exist_ok avoids a redundant existence check)
os.makedirs(HF_HOME, exist_ok=True)

# Validate environment variables
if not TOGETHER_AI_API:
    raise ValueError("Environment variable TOGETHER_AI_API is missing. Please set it in your .env file.")

# Initialize embeddings
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={"trust_remote_code": True},
    )
except Exception as e:
    logger.error(f"Error loading embeddings: {e}")
    raise RuntimeError("Error initializing HuggingFaceEmbeddings.")

# Load the FAISS vectorstore, or build it from ./data if it does not exist yet.
# Note: a "score_threshold" only takes effect with search_type="similarity_score_threshold";
# with plain "similarity" search it is silently ignored, so only "k" is passed here.
try:
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
except Exception as e:
    logger.error(f"Error loading FAISS vectorstore: {e}")
    # If the index is missing or unreadable, rebuild it from the raw documents.
    try:
        loader = DirectoryLoader("./data")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        documents = text_splitter.split_documents(loader.load())
        db = FAISS.from_documents(documents, embeddings)
        db.save_local("ipc_vector_db")
        db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    except Exception as inner_e:
        logger.error(f"Error creating FAISS vectorstore: {inner_e}")
        raise RuntimeError("FAISS vectorstore could not be created or loaded.")

# Define the prompt template
prompt_template = """
As a legal chatbot specializing in the Indian Penal Code (IPC), provide precise, fact-based answers to the user’s question based on the provided context. 
Respond only if the answer can be derived from the given context; otherwise, say: 
"The information is not available in the provided context."
Use plain, professional language in your response.

CONTEXT: {context}

CHAT HISTORY: {chat_history}

QUESTION: {question}

ANSWER:
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
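
# Example (sketch): render the prompt directly to inspect the exact text sent to the LLM.
#
#   preview = prompt.format(context="<retrieved chunks>", chat_history="", question="What is Section 420?")
#   logger.debug(preview)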

# Initialize the Together API
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.3,  # Lower temperature reduces randomness for more focused, consistent answers
        max_tokens=512,   # Shorter response for focus
        together_api_key=TOGETHER_AI_API,
    )
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Together API could not be initialized. Check your API key and network connection.")

# Initialize conversational retrieval chain.
# output_key="answer" tells the memory which chain output to store; this is required
# because return_source_documents=True makes the chain return multiple output keys.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    return_source_documents=True,
    combine_docs_chain_kwargs={"prompt": prompt},
)
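
# Example (sketch, not executed at import): the chain can be exercised outside FastAPI.
# Memory carries chat history between calls, so a follow-up question sees the prior turn.
#
#   first = qa.invoke({"question": "What is Section 302 of the IPC?"})
#   follow_up = qa.invoke({"question": "What punishment does it prescribe?"})
#   print(first["answer"], follow_up["answer"])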

# Initialize FastAPI app
app = FastAPI()

# Define request and response models
class ChatRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    answer: str

# Health check endpoint
@app.get("/")
async def root():
    return {"message": "Hello, World!"}

# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        logger.debug(f"User Question: {request.question}")
        result = qa.invoke(input=request.question)
        logger.debug(f"Retrieved Context: {result.get('context', '')}")
        logger.debug(f"Model Response: {result.get('answer', '')}")

        answer = result.get("answer", "The chatbot could not generate a response.")
        confidence_score = result.get("score", 0)  # Assuming LLM provides a score

        if confidence_score < 0.7:
            answer = "The answer is uncertain. Please consult a professional."

        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error during chat invocation: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
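
# Example request (assumes the server is running locally on port 7860):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the IPC say about culpable homicide?"}'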

# Start Uvicorn server if run directly
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860)
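
# Minimal test sketch (hypothetical; run separately, e.g. with pytest):
#
#   from fastapi.testclient import TestClient
#   client = TestClient(app)
#   assert client.get("/").json() == {"message": "Hello, World!"}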