# Spaces: Paused
import subprocess | |
import gradio as gr | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain.llms import HuggingFacePipeline | |
from langchain.chains import RetrievalQA | |
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM | |
from langchain_community.document_loaders import DirectoryLoader | |
import torch | |
import re | |
import requests | |
from urllib.parse import urlencode, urlparse, parse_qs | |
import spaces | |
# Step 1: run the setup script before any model or database is initialised.
script_path = './setup.sh'  # Adjust the path if needed

# subprocess.call blocks until bash exits and hands back the exit status.
exit_code = subprocess.call(['bash', script_path])

# Report the outcome; a non-zero status is only reported, start-up continues.
print(
    "Script executed successfully."
    if exit_code == 0
    else f"Script failed with exit code {exit_code}."
)
# Initialize embeddings and ChromaDB.
# The embedding model runs on the GPU when one is available, otherwise on CPU.
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Load every PDF found (recursively) under ./example.
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()

# Embed the documents into a Chroma collection stored under companies_db.
# NOTE(review): this re-ingests the PDFs on every start; if the directory
# survives restarts the collection may accumulate duplicates - confirm.
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory="companies_db")

# Reuse the store that was just built instead of opening a second handle on
# the same directory: a second handle is redundant and can observe an empty
# collection if the first one has not flushed its data to disk yet.
books_db = vectordb
books_db_client = books_db.as_retriever()
# Load the causal language model and its tokenizer.
# NOTE: model_name is rebound here; above it named the embedding model.
model_name = "stabilityai/stablelm-zephyr-3b"
model_config = AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=model_config,
    device_map=device,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sampling settings forwarded to the text-generation pipeline.
_generation_settings = {
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
    "max_new_tokens": 256,
}

# return_full_text=True keeps the prompt in the output.
# NOTE(review): test_rag() searches the output for "Helpful Answer:", which
# presumably comes from the chain's prompt - confirm before changing this flag.
query_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    torch_dtype=torch.float16,
    device_map=device,
    **_generation_settings,
)

# Wrap the pipeline for LangChain and build the retrieval-augmented QA chain
# on top of the Chroma retriever.
llm = HuggingFacePipeline(pipeline=query_pipeline)
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True,
)
# OAuth configuration for the Microsoft identity platform
# (authorization-code flow).
import os  # imported here because only this section reads the environment

TENANT_ID = '2b093ced-2571-463f-bc3e-b4f8bcb427ee'
CLIENT_ID = '2a7c884c-942d-49e2-9e5d-7a29d8a0d3e5'

# SECURITY: the client secret must not live in source control. It is read
# from the AZURE_CLIENT_SECRET environment variable instead. The value that
# was previously committed here should be treated as leaked and rotated.
CLIENT_SECRET = os.environ.get('AZURE_CLIENT_SECRET', '')
if not CLIENT_SECRET:
    print("Warning: AZURE_CLIENT_SECRET is not set; Microsoft login will fail.")

REDIRECT_URI = 'https://sanjeevbora-chatbot.hf.space/'
AUTH_URL = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/authorize"
TOKEN_URL = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/token"

# Query parameters of the authorization request.
# NOTE(review): 'state' is a fixed value and is not checked on the redirect,
# so it gives no CSRF protection; it should be a per-session random value
# that the redirect handler verifies.
params = {
    'client_id': CLIENT_ID,
    'response_type': 'code',
    'redirect_uri': REDIRECT_URI,
    'response_mode': 'query',
    'scope': 'User.Read',
    'state': '12345',
    'prompt': 'login',  # This ensures the login prompt appears even if already logged in
}

# Construct the login URL
login_url = f"{AUTH_URL}?{urlencode(params)}"
def show_login_button():
    """Return the HTML anchor that links to the Microsoft login URL."""
    anchor_html = f'<a href="{login_url}" class="GFG">Click here to login with Microsoft</a>'
    return anchor_html
# Function to exchange auth code for token
def exchange_code_for_token(auth_code):
    """Exchange an OAuth authorization code for an access token.

    Posts the code together with the client credentials to TOKEN_URL.

    Args:
        auth_code: The authorization code taken from the redirect URL.

    Returns:
        The 'access_token' value of the JSON response, or None when the
        request fails (network error or timeout), the endpoint answers with
        a status other than 200, the body is not valid JSON, or the body has
        no 'access_token' key.
    """
    payload = {
        'grant_type': 'authorization_code',
        'client_id': CLIENT_ID,
        'client_secret': CLIENT_SECRET,
        'code': auth_code,
        'redirect_uri': REDIRECT_URI,
    }
    try:
        # Without a timeout a stalled connection would block the UI callback
        # indefinitely.
        response = requests.post(TOKEN_URL, data=payload, timeout=10)
    except requests.RequestException:
        # Connection errors and timeouts are reported as a failed login
        # rather than raised into the Gradio event handler.
        return None

    if response.status_code != 200:
        return None

    try:
        token_data = response.json()
    except ValueError:
        # A 200 response whose body is not JSON carries no usable token.
        return None
    return token_data.get('access_token')
# Function to handle redirect URL and extract the auth code
def handle_redirect(url):
    """Turn a redirect URL into a login status.

    Reads the 'code' query parameter from the URL and, when present, trades
    its first value for a token via exchange_code_for_token.

    Returns:
        ("Logged in", True) when a token was obtained, otherwise
        ("Login failed", False).
    """
    codes = parse_qs(urlparse(url).query).get('code')
    # Short-circuit: the exchange is attempted only when a code is present.
    if codes and exchange_code_for_token(codes[0]):
        return "Logged in", True  # Successfully logged in
    return "Login failed", False
# Function to retrieve answer using the RAG system
def test_rag(query):
    """Run the query through the retrieval QA chain and return the answer text.

    The answer is everything that follows the first "Helpful Answer:" marker
    in the chain output, stripped of surrounding whitespace. When the marker
    is absent, a fixed fallback message is returned.
    """
    raw_output = books_db_client_retriever.run(query)
    # DOTALL lets the capture group span multiple lines of the answer.
    marker_match = re.search(r"Helpful Answer:(.*)", raw_output, re.DOTALL)
    if marker_match is None:
        return "No helpful answer found."
    return marker_match.group(1).strip()
# Chat callback used by the Gradio submit button
def chat(query, history=None):
    """Answer a query and record the exchange in the chat history.

    Args:
        query: The user's question; a falsy value leaves the history untouched.
        history: List of (question, answer) tuples; a new list is created
            when None. A supplied list is appended to in place.

    Returns:
        A tuple (history, "") where the empty string clears the input box.
    """
    history = [] if history is None else history
    if not query:
        return history, ""
    history.append((query, test_rag(query)))
    return history, ""  # Clear input after submission
# Function to clear input text
def clear_input():
    """Return a one-element tuple holding an empty string.

    NOTE(review): the tuple (rather than a bare "") is what this function has
    always returned; confirm that is what the consuming component expects.
    """
    cleared = ("",)
    return cleared
# Gradio interface: a "Login" tab whose result reveals the (initially hidden)
# components of the "Chatbot" tab.
with gr.Blocks() as interface:
    with gr.Tab("Login"):
        gr.Markdown("## Login Page")
        login_link = gr.HTML(show_login_button())
        # Textbox holding the URL from the Microsoft redirect; a change to
        # its value triggers the login check below.
        redirect_url_input = gr.Textbox(label="Redirect URL", visible=True)
        status_label = gr.Label(value="Not logged in")  # Label to show login status

    with gr.Tab("Chatbot"):
        gr.Markdown("## Chatbot Page")
        # Components for chat (initially hidden)
        input_box = gr.Textbox(
            label="Enter your question",
            placeholder="Type your question here...",
            visible=False,
        )
        submit_btn = gr.Button("Submit", visible=False)
        chat_history = gr.Chatbot(label="Chat History", visible=False)
        # Chat submission
        submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])

    def on_redirect(redirect_url):
        """Check the redirect URL and toggle component visibility.

        Returns one value per entry in the event's outputs list: the redirect
        textbox is hidden and the three chat components are shown on a
        successful login, and the reverse on failure. The status text goes to
        the status label in both cases.
        """
        # Extract and exchange token
        status, logged_in = handle_redirect(redirect_url)
        return (
            gr.update(visible=not logged_in),
            status,
            gr.update(visible=logged_in),
            gr.update(visible=logged_in),
            gr.update(visible=logged_in),
        )

    # The event is wired after both tabs are built because outputs must be
    # the component instances themselves. The previous code passed
    # `Component.update(...)` payloads here, which are not components, so the
    # chat widgets were never targeted.
    redirect_url_input.change(
        on_redirect,
        inputs=[redirect_url_input],
        outputs=[redirect_url_input, status_label, input_box, submit_btn, chat_history],
        show_progress=True,
    )

interface.launch()