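"""FastAPI service for a legal chatbot on the Indian Penal Code (IPC).

Loads (or builds) a FAISS vector store of IPC documents, wires it into a
LangChain ConversationalRetrievalChain backed by Together AI's
Mistral-7B-Instruct model, and exposes it via a /chat endpoint.
"""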
import logging
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from dotenv import load_dotenv
import warnings
import uvicorn
# Logging configuration
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.debug("Starting FastAPI app...")
# Suppress warnings
warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
warnings.filterwarnings("ignore", message="Tried to instantiate class '__path__._path'")
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Load environment variables
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
HF_HOME = os.getenv("HF_HOME", "./cache")
# Set cache directory for Hugging Face
os.environ["HF_HOME"] = HF_HOME
# Ensure HF_HOME exists and is writable
if not os.path.exists(HF_HOME):
    os.makedirs(HF_HOME, exist_ok=True)
# Validate environment variables
if not TOGETHER_AI_API:
    raise ValueError("Environment variable TOGETHER_AI is missing. Please set it in your .env file.")
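# A minimal .env sketch (values are placeholders, not real credentials):
#   TOGETHER_AI=<your Together AI API key>
#   HF_HOME=./cache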
# Initialize embeddings
try:
    # nomic-embed-text-v1 ships custom model code, hence trust_remote_code=True
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={"trust_remote_code": True},
    )
except Exception as e:
    logger.error(f"Error loading embeddings: {e}")
    raise RuntimeError("Error initializing HuggingFaceEmbeddings.") from e
# Ensure FAISS vectorstore is loaded or created
try:
    # allow_dangerous_deserialization is required because FAISS indexes are
    # pickled; only load an index you created yourself.
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
except Exception as e:
    logger.error(f"Error loading FAISS vectorstore: {e}")
    # If not found, build a new vectorstore from the documents in ./data
    try:
        loader = DirectoryLoader('./data')
        text_splitter = RecursiveCharacterTextSplitter()
        documents = text_splitter.split_documents(loader.load())
        db = FAISS.from_documents(documents, embeddings)
        db.save_local("ipc_vector_db")
        db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
    except Exception as inner_e:
        logger.error(f"Error creating FAISS vectorstore: {inner_e}")
        raise RuntimeError("FAISS vectorstore could not be created or loaded.") from inner_e
# Define the prompt template
prompt_template = """<s>[INST] As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context. Avoid unnecessary details or unrelated content. Only respond if the answer can be derived from the provided context; otherwise, say "The information is not available in the provided context."
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER: [/INST]
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
# Initialize the Together API
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Together API could not be initialized. Check your API key and network connection.") from e
# Initialize conversational retrieval chain
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
)
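# Note: memory is module-level state, so chat history is shared across all
# API clients; per-session memory would need a session-keyed store instead.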
# Initialize FastAPI app
app = FastAPI()
# Define request and response models
class ChatRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    answer: str
# Health check endpoint
@app.get("/")
async def root():
    return {"message": "Hello, World!"}
# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        # The chain expects a "question" key; chat history is supplied by its memory
        result = qa.invoke({"question": request.question})
        answer = result.get("answer", "The chatbot could not generate a response.")
        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error during chat invocation: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
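# Example request once the server is up (port assumes the default below):
#   curl -X POST http://localhost:10000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the IPC say about theft?"}'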
# Start Uvicorn server if run directly
if __name__ == "__main__":
    ENV = os.getenv("ENV", "prod")
    PORT = int(os.environ.get("PORT", 10000))  # Default to port 10000 unless PORT is set
    uvicorn.run("main:app", host="0.0.0.0", port=PORT, reload=(ENV == "dev"))