import os
import re
from urllib.parse import urlencode

import gradio as gr
import spaces
import torch
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import DirectoryLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

# Initialize embeddings and ChromaDB
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Load every PDF under ./example and index it into a persistent Chroma store
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory="companies_db")
books_db = Chroma(persist_directory="./companies_db", embedding_function=embeddings)
books_db_client = books_db.as_retriever()

# Initialize the model and tokenizer
model_name = "stabilityai/stablelm-zephyr-3b"
model_config = AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

query_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    torch_dtype=torch.float16,
    device_map=device,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,
)

llm = HuggingFacePipeline(pipeline=query_pipeline)

books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True,
)

# Function to retrieve an answer using the RAG system: run the chain, then
# extract the text following the "Helpful Answer:" marker from the output
@spaces.GPU(duration=60)
def test_rag(query):
    books_retriever = books_db_client_retriever.run(query)

    corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
    if corrected_text_match:
        corrected_text_books = corrected_text_match.group(1).strip()
    else:
        corrected_text_books = "No helpful answer found."

    return corrected_text_books
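# Note: DirectoryLoader above indexes each loaded document as a single entry.
# Splitting documents into overlapping chunks before embedding usually improves
# retrieval quality. A minimal optional sketch, not wired into the pipeline;
# the helper name and the chunk_size/chunk_overlap defaults are illustrative,
# not tuned values:
def split_docs_for_indexing(documents, chunk_size=1000, chunk_overlap=100):
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(documents)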
# OAuth configuration, read from environment variables
TENANT_ID = os.getenv("TENANT_ID")
CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
REDIRECT_URI = os.getenv("SPACE_HOST")
AUTH_URL = os.getenv("AUTH_URL")
TOKEN_URL = os.getenv("TOKEN_URL")
SCOPE = os.getenv("SCOPE")

access_token = None

# OAuth login functionality: build the provider's authorization URL
def oauth_login():
    params = {
        'client_id': CLIENT_ID,
        'response_type': 'code',
        'redirect_uri': REDIRECT_URI,
        'response_mode': 'query',
        'scope': SCOPE,
        'state': 'random_state_string'  # Optional: use for CSRF protection
    }
    login_url = f"{AUTH_URL}?{urlencode(params)}"
    return login_url

# Chat handler for the Gradio interface
def chat(query, history=None):
    if history is None:
        history = []
    if query:
        answer = test_rag(query)
        history.append((query, answer))
    return history, ""  # Clear input after submission

# Function to clear the input text
def clear_input():
    return ""  # Return an empty string to clear the input field

# Function to show/hide components based on login status
def on_login(success):
    return gr.update(visible=success), gr.update(visible=success)

with gr.Blocks() as interface:
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")

    # Sign-in button plus a link that receives the authorization URL
    login_btn = gr.Button("Sign In with HF")
    login_link = gr.Markdown(visible=False)

    # Components hidden until after login
    input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...", visible=False)
    submit_btn = gr.Button("Submit", visible=False)
    chat_history = gr.Chatbot(label="Chat History", visible=False)

    # Show the chat components after login
    def show_components():
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

    # Clicking sign-in surfaces the OAuth login link (Gradio cannot redirect
    # server-side) and reveals the input box, submit button, and chat history
    login_btn.click(lambda: gr.update(value=f"[Continue sign-in]({oauth_login()})", visible=True), outputs=login_link)
    login_btn.click(show_components, outputs=[input_box, submit_btn, chat_history])

    submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])

interface.launch()
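# Note: the sign-in flow above only builds the authorization URL. Completing the
# authorization-code grant also requires handling the redirect back to
# REDIRECT_URI and exchanging the returned code for a token, which this app does
# not implement. A minimal sketch using `requests` (standard OAuth2 token
# request fields; hypothetical helper, not wired into the interface):
#
#     import requests
#
#     def exchange_code_for_token(code):
#         resp = requests.post(TOKEN_URL, data={
#             "client_id": CLIENT_ID,
#             "client_secret": CLIENT_SECRET,
#             "grant_type": "authorization_code",
#             "code": code,
#             "redirect_uri": REDIRECT_URI,
#         })
#         resp.raise_for_status()
#         return resp.json().get("access_token")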