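# ChatBot / app.py
# Hosting-page metadata captured with this file (not part of the program):
#   uploader: sanjeevbora; commit message: "updated app.py"; commit 618013c (verified)
#   page links: raw, history, blame; file size: 3.76 kB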
import os
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
from langchain_community.document_loaders import DirectoryLoader
import torch
import re
import transformers
import spaces
import requests
# Initialize embeddings and ChromaDB.
# Everything in this section runs at import time, top to bottom.
# Sentence-embedding model used to vectorize the documents.
model_name = "sentence-transformers/all-mpnet-base-v2"
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
# `device` is reused further down when the language model is loaded.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
# Collect every PDF under ./example, including sub-directories.
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
# Embed the loaded documents into a Chroma store persisted under "companies_db".
# NOTE(review): this runs on every start; if the directory survives restarts the
# same documents are presumably added again — confirm whether that is intended.
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory="companies_db")
# Second handle on the same directory (spelled "./companies_db" here); this is
# the store the retriever below is built from.
# NOTE(review): `vectordb` above is not used again in this file — confirm
# whether the second client is needed.
books_db = Chroma(persist_directory="./companies_db", embedding_function=embeddings)
# Retriever with default search settings, passed to the RetrievalQA chain below.
books_db_client = books_db.as_retriever()
# Initialize the model and tokenizer.
# `model_name` is rebound here: from this point on it names the causal language
# model, not the embedding model configured earlier.
model_name = "stabilityai/stablelm-zephyr-3b"
# NOTE(review): max_new_tokens is given as 1024 on the config here and as 256 on
# the pipeline below — confirm which limit is intended.
model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
# Load the model weights onto `device` (chosen above from CUDA availability).
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
config=model_config,
device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Text-generation pipeline with sampling enabled (temperature / top_p / top_k).
# return_full_text=True is requested; test_rag() later searches the returned
# text for the "Helpful Answer:" marker.
# NOTE(review): torch_dtype and device_map are passed here although `model` is
# already instantiated above without a dtype — confirm they take effect.
query_pipeline = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
return_full_text=True,
torch_dtype=torch.float16,
device_map=device,
do_sample=True,
temperature=0.7,
top_p=0.9,
top_k=50,
max_new_tokens=256
)
# Wrap the transformers pipeline so LangChain can use it as an LLM.
llm = HuggingFacePipeline(pipeline=query_pipeline)
# Retrieval-QA chain: documents come from `books_db_client`, answers from `llm`.
# verbose=True is set on the chain.
books_db_client_retriever = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=books_db_client,
verbose=True
)
# Answer a question with the RAG chain and return only the answer text
@spaces.GPU(duration=60)
def test_rag(query):
    """Run *query* through the retrieval-QA chain and extract the answer.

    The chain's text output is searched for the marker ``Helpful Answer:``.
    Everything after the first occurrence of the marker (newlines included)
    is returned with surrounding whitespace removed.

    Args:
        query: Question text handed to ``books_db_client_retriever.run``.

    Returns:
        The extracted answer, or ``"No helpful answer found."`` when the
        marker does not appear in the chain output.
    """
    raw_output = books_db_client_retriever.run(query)

    # DOTALL lets the capture group span multiple lines of generated text.
    marker_match = re.search(r"Helpful Answer:(.*)", raw_output, re.DOTALL)
    if marker_match is None:
        return "No helpful answer found."
    return marker_match.group(1).strip()
# OAuth Login Functionality
def oauth_login():
    """Build the Hugging Face OAuth authorization URL for this Space.

    Reads ``OAUTH_CLIENT_ID`` and ``SPACE_HOST`` from the environment and
    assembles the ``https://huggingface.co/oauth/authorize`` URL with the
    ``redirect_uri``, ``scope``, ``client_id`` and ``state`` query parameters.
    If either variable is unset, the literal text ``None`` appears in its
    place, as it did before; no exception is raised.

    Returns:
        str: The authorization URL. Query values are percent-encoded, and
        ``state`` is a fresh unguessable token on every call.
    """
    # Function-scope imports: these names are not imported at the top of the
    # file, and keeping them here makes the block self-contained.
    import secrets
    from urllib.parse import quote, urlencode

    client_id = os.getenv("OAUTH_CLIENT_ID")
    redirect_uri = f"https://{os.getenv('SPACE_HOST')}/login/callback"

    # A constant state value gives no CSRF protection; use a random token.
    # NOTE(review): the token is not stored anywhere here, so whatever handles
    # /login/callback cannot verify it yet — wire that up with the callback.
    state = secrets.token_urlsafe(32)

    params = {
        "redirect_uri": redirect_uri,
        "scope": "openid profile",
        "client_id": client_id,
        "state": state,
    }
    # quote (rather than the default quote_plus) keeps the space in the scope
    # encoded as %20, matching the previous URL, and also escapes the
    # ':' and '/' characters inside redirect_uri.
    query = urlencode(params, quote_via=quote)
    return f"https://huggingface.co/oauth/authorize?{query}"
# Submit handler used by the Gradio interface
def chat(query, history=None):
    """Append one question/answer turn to the chat history.

    Args:
        query: Text from the input box. A falsy value (``None`` or ``""``)
            leaves the history unchanged.
        history: List of ``(question, answer)`` tuples, or ``None`` to start
            a new list. A list passed in is extended in place.

    Returns:
        A ``(history, "")`` tuple; the empty string clears the input box
        after submission.
    """
    turns = [] if history is None else history
    if not query:
        return turns, ""

    turns.append((query, test_rag(query)))
    return turns, ""
# Function to clear input text
def clear_input():
    """Return the value that resets the question input field.

    Returns:
        str: An empty string. A stray trailing comma previously made this a
        one-element tuple ``("",)``, which a handler bound to a single output
        component would pass through as the component's value instead of
        clearing it.
    """
    return ""
# Gradio interface: page layout, event wiring, and app start-up
with gr.Blocks() as interface:
    # Heading and a one-line usage hint.
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")

    # Question entry, submit button, and the running transcript.
    input_box = gr.Textbox(
        label="Enter your question",
        placeholder="Type your question here...",
    )
    submit_btn = gr.Button("Submit")
    chat_history = gr.Chatbot(label="Chat History")

    # Sign-In Button
    login_btn = gr.Button("Sign In with HF")
    # NOTE(review): the handler returns the login URL, but outputs=None means
    # nothing receives it — confirm how the redirect is meant to happen.
    login_btn.click(fn=lambda: oauth_login(), outputs=None)

    # Submitting sends the question plus current history to chat(), which
    # returns the updated history and an empty string for the input box.
    submit_btn.click(
        fn=chat,
        inputs=[input_box, chat_history],
        outputs=[chat_history, input_box],
    )

# Start the app as soon as the module is executed.
interface.launch()