# ChatBot / app.py
# Author: sanjeevbora — commit 7880dfb (verified): "updated app.py" (4.49 kB)
import os
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
from langchain_community.document_loaders import DirectoryLoader
import torch
import re
import transformers
import spaces
from urllib.parse import urlencode
# Initialize embeddings and ChromaDB.
# A dedicated name avoids shadowing by the LLM's `model_name` further down.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DOCS_DIR = "./example"
PERSIST_DIR = "./companies_db"

device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs=model_kwargs)

# Load every PDF under DOCS_DIR (recursively).
# NOTE(review): documents are embedded without any text splitting; long PDFs may
# exceed the embedding model's input window — consider adding a splitter.
loader = DirectoryLoader(DOCS_DIR, glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
if not docs:
    # Fail fast with a clear message instead of an opaque vector-store error
    # (or an index that can never retrieve anything).
    raise RuntimeError(f"No PDF documents found under {DOCS_DIR!r}; nothing to index.")

# NOTE(review): if PERSIST_DIR survives a restart, from_documents likely adds the
# same documents again — consider clearing the existing collection first.
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory=PERSIST_DIR)
# Reuse the store just built rather than opening a second client on the same directory.
books_db = vectordb
books_db_client = books_db.as_retriever()
# Initialize the model and tokenizer.
model_name = "stabilityai/stablelm-zephyr-3b"
# Half precision on GPU to save memory; float32 on CPU, where float16 kernels
# are often missing or slow.
model_dtype = torch.float16 if device == "cuda" else torch.float32
model_config = transformers.AutoConfig.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    device_map=device,
    # The dtype has to be chosen at load time: pipeline()'s torch_dtype is only
    # forwarded to from_pretrained and is ignored for an already-loaded model.
    torch_dtype=model_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# All generation settings live here, so there is a single max_new_tokens.
# return_full_text=True keeps the prompt in the output; test_rag relies on that
# to find the "Helpful Answer:" marker of the QA prompt.
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,
)
llm = HuggingFacePipeline(pipeline=query_pipeline)
# Assemble the question-answering chain around the LLM and the document retriever.
# chain_type="stuff" selects LangChain's strategy of inserting the retrieved
# documents into a single prompt; verbose=True turns on the chain's run logging.
books_db_client_retriever = RetrievalQA.from_chain_type(
    chain_type="stuff",
    llm=llm,
    retriever=books_db_client,
    verbose=True,
)
# Function to retrieve an answer using the RAG system.
@spaces.GPU(duration=60)
def test_rag(query):
    """Answer ``query`` with the RetrievalQA chain and return only the answer text.

    The chain's output contains the prompt followed by the generated text, and
    the answer is taken as everything after the first "Helpful Answer:" marker.

    Args:
        query: The user's question.

    Returns:
        The stripped answer text, or "No helpful answer found." when the marker
        does not appear in the chain output.
    """
    # Chain.run is deprecated; invoke() takes the chain's input dict and returns
    # its output dict, whose "result" entry is the string run() used to return.
    raw_output = books_db_client_retriever.invoke({"query": query})["result"]
    # NOTE(review): if the query or retrieved context itself contains
    # "Helpful Answer:", the first match lands there rather than on the answer.
    match = re.search(r"Helpful Answer:(.*)", raw_output, re.DOTALL)
    if match is None:
        return "No helpful answer found."
    return match.group(1).strip()
# OAuth Configuration
TENANT_ID = os.getenv("TENANT_ID")
CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")


def absolute_redirect_uri(host):
    """Turn a bare host name into an absolute https URL usable as a redirect_uri.

    Args:
        host: A host such as "user-space.hf.space", an already-absolute URL,
            or None/"" when unset.

    Returns:
        ``host`` unchanged if it is empty or already starts with http:// or
        https://; otherwise ``"https://" + host``.
    """
    if not host or host.startswith(("http://", "https://")):
        return host
    return f"https://{host}"


# SPACE_HOST is a bare domain on Hugging Face Spaces, but OAuth providers expect
# an absolute redirect URI.
# NOTE(review): the URI registered with the provider may also need a callback path.
REDIRECT_URI = absolute_redirect_uri(os.getenv("SPACE_HOST"))
AUTH_URL = os.getenv("AUTH_URL")
TOKEN_URL = os.getenv("TOKEN_URL")
SCOPE = os.getenv("SCOPE")
# NOTE(review): TENANT_ID, CLIENT_SECRET, TOKEN_URL and access_token are not read
# by any code in this file; no token exchange is implemented yet.
access_token = None
# OAuth Login Functionality
def oauth_login(state=None):
    """Build the provider authorization URL for the OAuth authorization-code flow.

    Args:
        state: Anti-CSRF value to round-trip through the provider. When omitted,
            a fresh unguessable token is generated on every call.

    Returns:
        ``AUTH_URL`` with the OAuth query parameters appended. Settings that are
        unset (None) are left out of the query string.

    Raises:
        RuntimeError: If ``AUTH_URL`` is not configured.
    """
    import secrets

    if not AUTH_URL:
        raise RuntimeError("AUTH_URL is not configured; cannot build the sign-in URL.")
    if state is None:
        # A constant state string is predictable and gives no CSRF protection.
        # NOTE(review): the state must also be stored and compared on the
        # callback; no callback handler exists in this file.
        state = secrets.token_urlsafe(16)
    params = {
        "client_id": CLIENT_ID,
        "response_type": "code",
        "redirect_uri": REDIRECT_URI,
        "response_mode": "query",
        "scope": SCOPE,
        "state": state,
    }
    # urlencode would otherwise emit unset settings as the literal text "None".
    query = urlencode({key: value for key, value in params.items() if value is not None})
    return f"{AUTH_URL}?{query}"
# Chat handler: answers one question and appends it to the visible history.
def chat(query, history=None):
    """Append one (question, answer) exchange to the chat history.

    Args:
        query: Text typed by the user. A falsy value (empty string or None)
            skips the model call entirely.
        history: List of (question, answer) pairs, or None to start a new one.
            A provided list is appended to in place.

    Returns:
        A pair ``(history, "")``: the (possibly updated) history and an empty
        string that clears the input box.
    """
    conversation = [] if history is None else history
    if not query:
        return conversation, ""
    conversation.append((query, test_rag(query)))
    return conversation, ""
# Function to clear input text
def clear_input():
    """Return an empty string so that a Textbox bound as the output is cleared.

    Returns:
        "" — a plain string, not a one-element tuple.
    """
    return ""
with gr.Blocks() as interface:
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")

    # Sign-In Button
    login_btn = gr.Button("Sign In with HF")
    # Hidden until the sign-in button is clicked; then shows the provider link.
    login_link = gr.Markdown(visible=False)

    # Hidden components initially
    input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...", visible=False)
    submit_btn = gr.Button("Submit", visible=False)
    chat_history = gr.Chatbot(label="Chat History", visible=False)

    def show_components():
        """Return updates making the input box, Submit button and chat history visible."""
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

    def start_login():
        """Show the provider sign-in link and reveal the chat components.

        NOTE(review): no authorization code is exchanged or validated anywhere in
        this file, so the reveal is UI-only and does not enforce authentication.
        """
        login_url = oauth_login()
        link_update = gr.update(value=f"[Continue to sign-in]({login_url})", visible=True)
        return (link_update, *show_components())

    # The reveal is bound to the always-visible login button; binding it to the
    # initially hidden Submit button would leave the chat UI unreachable.
    login_btn.click(start_login, outputs=[login_link, input_box, submit_btn, chat_history])
    submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])

interface.launch()