import os
import re
from urllib.parse import urlencode

import gradio as gr
import spaces
import torch
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

# Initialize embeddings and ChromaDB
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Load every PDF under ./example and index it into a persistent Chroma store
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory="companies_db")

# Re-open the persisted store and expose it as a retriever
books_db = Chroma(persist_directory="./companies_db", embedding_function=embeddings)
books_db_client = books_db.as_retriever()

# Initialize the model and tokenizer
model_name = "stabilityai/stablelm-zephyr-3b"
model_config = AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    device_map=device,
    torch_dtype=torch.float16,  # dtype must be set at load time; pipeline() ignores it once a model object is passed
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

query_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    device_map=device,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,
)

llm = HuggingFacePipeline(pipeline=query_pipeline)

books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True,
)

# Retrieve an answer using the RAG chain
@spaces.GPU(duration=60)
def test_rag(query):
    books_retriever = books_db_client_retriever.run(query)

    # With return_full_text=True the output includes the whole prompt; the
    # default "stuff" prompt ends in "Helpful Answer:", so keep only what follows it
    corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
    if corrected_text_match:
        corrected_text_books = corrected_text_match.group(1).strip()
    else:
        corrected_text_books = "No helpful answer found."
    return corrected_text_books
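
# Sketch of an optional refinement (assumes the `langchain_text_splitters`
# package, which this app does not install, and arbitrary chunk sizes):
# DirectoryLoader indexes each PDF as a single document, so the "stuff" chain
# can end up packing entire files into one prompt. Splitting `docs` before
# building the store usually improves retrieval:
#
#     from langchain_text_splitters import RecursiveCharacterTextSplitter
#     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
#     chunks = splitter.split_documents(docs)
#     vectordb = Chroma.from_documents(documents=chunks, embedding=embeddings,
#                                      persist_directory="companies_db")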

# OAuth configuration
TENANT_ID = os.getenv("TENANT_ID")
CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
# On HF Spaces, SPACE_HOST is a bare hostname; the redirect URI registered
# with the OAuth provider usually needs the https:// scheme prepended
REDIRECT_URI = os.getenv("SPACE_HOST")
AUTH_URL = os.getenv("AUTH_URL")
TOKEN_URL = os.getenv("TOKEN_URL")
SCOPE = os.getenv("SCOPE")

access_token = None

# Build the OAuth authorization URL
def oauth_login():
    params = {
        'client_id': CLIENT_ID,
        'response_type': 'code',
        'redirect_uri': REDIRECT_URI,
        'response_mode': 'query',
        'scope': SCOPE,
        'state': 'random_state_string',  # Optional: use a random value for CSRF protection
    }
    login_url = f"{AUTH_URL}?{urlencode(params)}"
    return login_url

# Chat handler for the Gradio interface
def chat(query, history=None):
    if history is None:
        history = []
    if query:
        answer = test_rag(query)
        history.append((query, answer))
    return history, ""  # Clear the input after submission

# Clear the input text (defined but not wired to any event below)
def clear_input():
    return ""  # Return an empty string to clear the input field

with gr.Blocks() as interface:
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")

    # Sign-in button plus a hidden Markdown slot that will hold the login link
    login_btn = gr.Button("Sign In with HF")
    login_link = gr.Markdown(visible=False)

    # Components hidden until the user starts the login flow
    input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...", visible=False)
    submit_btn = gr.Button("Submit", visible=False)
    chat_history = gr.Chatbot(label="Chat History", visible=False)

    # Surface the login URL as a clickable link and reveal the chat widgets;
    # the sign-in button must drive this, since the submit button starts hidden
    def show_components():
        login_url = oauth_login()
        return (
            gr.update(value=f"[Continue sign-in]({login_url})", visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    login_btn.click(show_components, outputs=[login_link, input_box, submit_btn, chat_history])
    submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])

interface.launch()
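
# Sketch of the token-exchange step the configuration above provides for:
# TOKEN_URL, CLIENT_SECRET, and `access_token` are loaded, but the
# authorization code returned to REDIRECT_URI is never exchanged for a token.
# A minimal version might look like the function below. Treat it as an
# illustrative assumption, not a wired-in handler: nothing in the UI calls it,
# `requests` is an assumed extra dependency, and the exact payload is
# provider-specific (the TENANT_ID/AUTH_URL/TOKEN_URL names suggest an OAuth2
# authorization-code flow such as Microsoft Entra ID's v2 endpoints).
def exchange_code_for_token(auth_code):
    import requests  # assumed dependency, not imported at the top of this app

    # Standard OAuth2 authorization-code grant: POST the code plus client
    # credentials to the token endpoint and read the access token from JSON
    response = requests.post(
        TOKEN_URL,
        data={
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "grant_type": "authorization_code",
            "code": auth_code,
            "redirect_uri": REDIRECT_URI,
            "scope": SCOPE,
        },
    )
    response.raise_for_status()
    return response.json().get("access_token")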