import os
import re
from urllib.parse import urlencode

import torch
import gradio as gr
import spaces
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import DirectoryLoader
from langchain.chains import RetrievalQA
# Initialize embeddings and ChromaDB
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
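# Recursively load every PDF under ./example into LangChain Document objects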
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
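# Embed the documents into a persistent Chroma index on disk, then re-open it for querying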
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory="./companies_db")
books_db = Chroma(persist_directory="./companies_db", embedding_function=embeddings)
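# Expose the store as a retriever (LangChain's default typically returns the top-4 most similar chunks)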
books_db_client = books_db.as_retriever()
# Initialize the model and tokenizer
model_name = "stabilityai/stablelm-zephyr-3b"
model_config = AutoConfig.from_pretrained(model_name)  # generation length is set on the pipeline below, not here
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    torch_dtype=torch.float16,  # dtype must be set at load time; it is ignored when passed to a pipeline with a pre-loaded model
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Device and dtype are already fixed on the model above, so they are not repeated here
query_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # keep the prompt in the output so "Helpful Answer:" can be located below
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,
)
llm = HuggingFacePipeline(pipeline=query_pipeline)
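# "stuff" chain type: all retrieved chunks are concatenated into a single prompt for the LLM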
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True,
)
# Function to retrieve answer using the RAG system
@spaces.GPU(duration=60)
def test_rag(query):
    books_retriever = books_db_client_retriever.run(query)
    # The pipeline echoes the full prompt (return_full_text=True), so keep only
    # the text after LangChain's "Helpful Answer:" marker
    corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
    if corrected_text_match:
        corrected_text_books = corrected_text_match.group(1).strip()
    else:
        corrected_text_books = "No helpful answer found."
    return corrected_text_books
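# Quick sanity check (hypothetical query; assumes ./example contained at least one PDF):
#   test_rag("What does the annual report say about revenue?")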
# OAuth Configuration
TENANT_ID = os.getenv("TENANT_ID")
CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
REDIRECT_URI = os.getenv("SPACE_HOST")  # NOTE: SPACE_HOST is a bare hostname on Spaces; prepend "https://" if the registered redirect URI requires a scheme
AUTH_URL = os.getenv("AUTH_URL")
TOKEN_URL = os.getenv("TOKEN_URL")
SCOPE = os.getenv("SCOPE")
access_token = None
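# NOTE: access_token is never assigned in this file; the code-for-token exchange
# (which would use TOKEN_URL and CLIENT_SECRET) is not wired up yet.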
# OAuth Login Functionality
def oauth_login():
    params = {
        'client_id': CLIENT_ID,
        'response_type': 'code',
        'redirect_uri': REDIRECT_URI,
        'response_mode': 'query',
        'scope': SCOPE,
        'state': 'random_state_string'  # should be an unguessable per-request value to prevent CSRF
    }
    login_url = f"{AUTH_URL}?{urlencode(params)}"
    return login_url
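# Example shape of the generated URL, assuming Azure AD style endpoints (an assumption, suggested
# by TENANT_ID above): https://login.microsoftonline.com/<tenant>/oauth2/v2.0/authorize?client_id=...&response_type=code&...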
# Define the Gradio interface
def chat(query, history=None):
    if history is None:
        history = []
    if query:
        answer = test_rag(query)
        history.append((query, answer))  # gr.Chatbot renders history as (user, bot) pairs
    return history, ""  # empty string clears the input box after submission
# Function to clear input text
def clear_input():
    return ""  # empty string clears the input field (no trailing comma, which would return a 1-tuple)
with gr.Blocks() as interface:
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")

    # Sign-In button
    login_btn = gr.Button("Sign In with HF")

    # Chat components, hidden until after login
    input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...", visible=False)
    submit_btn = gr.Button("Submit", visible=False)
    chat_history = gr.Chatbot(label="Chat History", visible=False)

    # Reveal the chat components
    def show_components():
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

    # Open the OAuth login page in a new tab (client-side; assumes a Gradio version that
    # supports the `js` event argument), then reveal the chat UI. Note that this does not
    # verify the OAuth callback; the components appear as soon as the button is clicked.
    login_btn.click(None, js=f"() => window.open('{oauth_login()}', '_blank')")
    login_btn.click(show_components, outputs=[input_box, submit_btn, chat_history])

    submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])

interface.launch()