import logging
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from dotenv import load_dotenv
import warnings
import uvicorn

# Logging configuration
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.debug("Starting FastAPI app...")

# Suppress warnings
warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
warnings.filterwarnings("ignore", message="Tried to instantiate class '__path__._path'")
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Load environment variables
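# Expects a .env file providing TOGETHER_AI (Together API key) and, optionally, HF_HOME.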
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
HF_HOME = os.getenv("HF_HOME", "./cache")

# Set cache directory for Hugging Face
os.environ["HF_HOME"] = HF_HOME

# Ensure HF_HOME exists and is writable
os.makedirs(HF_HOME, exist_ok=True)

# Validate environment variables
if not TOGETHER_AI_API:
    raise ValueError("Environment variable TOGETHER_AI is missing. Please set it in your .env file.")

# Initialize embeddings
try:
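    # nomic-embed-text-v1 ships custom modeling code; trust_remote_code=True lets Transformers load it.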
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={"trust_remote_code": True},
    )
except Exception as e:
    logger.error(f"Error loading embeddings: {e}")
    raise RuntimeError("Error initializing HuggingFaceEmbeddings.")

# Ensure FAISS vectorstore is loaded or created
try:
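    # save_local stores the index via pickle, so loading requires allow_dangerous_deserialization=True.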
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
except Exception as e:
    logger.error(f"Error loading FAISS vectorstore: {e}")
    # If not found, create a new vectorstore
    try:
        # Assumes ./data holds PDF sources; PyPDFLoader (imported above) parses them page by page.
        loader = DirectoryLoader('./data', glob="**/*.pdf", loader_cls=PyPDFLoader)
        # Explicit chunking keeps each chunk comfortably within the embedder's context window.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        documents = text_splitter.split_documents(loader.load())
        db = FAISS.from_documents(documents, embeddings)
        db.save_local("ipc_vector_db")
        db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
    except Exception as inner_e:
        logger.error(f"Error creating FAISS vectorstore: {inner_e}")
        raise RuntimeError("FAISS vectorstore could not be created or loaded.")

# Define the prompt template
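# Mistral instruct format: the full instruction sits between [INST] and [/INST]; the model generates after [/INST].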
prompt_template = """<s>[INST] As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context. Avoid unnecessary details or unrelated content. Only respond if the answer can be derived from the provided context; otherwise, say "The information is not available in the provided context."
    CONTEXT: {context}
    CHAT HISTORY: {chat_history}
    QUESTION: {question}
    ANSWER:
    [/INST]
    """
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])

# Initialize the Together API
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Together API could not be initialized. Check your API key and network connection.")

# Initialize conversational retrieval chain
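# memory_key="chat_history" must match the {chat_history} placeholder in the prompt template.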
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
)

# Initialize FastAPI app
app = FastAPI()

# Define request and response models
class ChatRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    answer: str

# Health check endpoint
@app.get("/")
async def root():
    return {"message": "Hello, World!"}

# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
def chat(request: ChatRequest):
    try:
        # The chain call blocks, so a plain (sync) endpoint lets FastAPI run it in a worker thread.
        # ConversationalRetrievalChain expects its input under the "question" key.
        result = qa.invoke({"question": request.question})
        answer = result.get("answer", "The chatbot could not generate a response.")
        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error during chat invocation: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")

# Start Uvicorn server if run directly
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
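
# Example request once the server is running (the question below is illustrative):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is the punishment for theft under the IPC?"}'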