import os
import re
from urllib.parse import urlencode

import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import Chroma
import spaces  # Hugging Face Spaces ZeroGPU helper

# Initialize embeddings and ChromaDB
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
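
# Illustrative (not executed): embeddings.embed_query("test") returns a
# 768-dimensional vector for all-mpnet-base-v2; Chroma stores one such
# vector per indexed document chunk.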

# Load PDFs from ./example and index them into a persistent Chroma store.
# DirectoryLoader defaults to the Unstructured loader, so the `unstructured`
# package (with PDF support) must be installed.
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory="./companies_db")
books_db = Chroma(persist_directory="./companies_db", embedding_function=embeddings)
books_db_client = books_db.as_retriever()
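
# Illustrative (not executed): the retriever can be probed on its own, e.g.
#   docs = books_db_client.get_relevant_documents("What does the company do?")
#   print(docs[0].page_content[:200])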

# Initialize the generation model and tokenizer
llm_model_name = "stabilityai/stablelm-zephyr-3b"
model_config = transformers.AutoConfig.from_pretrained(llm_model_name)  # generation length is set on the pipeline below
model = transformers.AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    trust_remote_code=True,
    config=model_config,
    torch_dtype=torch.float16,  # load weights in half precision to cut memory use
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # output includes the prompt; test_rag() strips it below
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,  # caps the length of each generated answer
)

llm = HuggingFacePipeline(pipeline=query_pipeline)

# RetrievalQA "stuff" chain: all retrieved chunks are placed into a single prompt
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True
)
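
# Illustrative (not executed): the chain can be invoked directly, e.g.
#   print(books_db_client_retriever.run("Summarize the latest annual report"))
# With verbose=True, LangChain also logs the assembled prompt as it runs.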

# Retrieve an answer using the RAG chain, keeping only the model's reply
@spaces.GPU(duration=60)  # request a ZeroGPU slot for up to 60 seconds
def test_rag(query):
    books_retriever = books_db_client_retriever.run(query)
    # return_full_text=True means the output still contains the prompt, so
    # keep only what follows the QA template's "Helpful Answer:" marker
    corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)

    if corrected_text_match:
        corrected_text_books = corrected_text_match.group(1).strip()
    else:
        corrected_text_books = "No helpful answer found."

    return corrected_text_books
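
# Worked example (illustrative): for a raw chain output ending in
#   "Helpful Answer: The fund returned 8% in 2023."
# the regex captures " The fund returned 8% in 2023." and strip() yields
# "The fund returned 8% in 2023." re.DOTALL lets the capture span newlines,
# so multi-line answers are kept intact.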

# OAuth Configuration
TENANT_ID = os.getenv("TENANT_ID")
CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
REDIRECT_URI = os.getenv("SPACE_HOST")  # Note: on HF Spaces, SPACE_HOST is a bare hostname; the registered redirect URI usually needs the https:// scheme prepended
AUTH_URL = os.getenv("AUTH_URL")
TOKEN_URL = os.getenv("TOKEN_URL")
SCOPE = os.getenv("SCOPE")
access_token = None

# OAuth Login Functionality
def oauth_login():
    params = {
        'client_id': CLIENT_ID,
        'response_type': 'code',
        'redirect_uri': REDIRECT_URI,
        'response_mode': 'query',
        'scope': SCOPE,
        'state': 'random_state_string'  # Should be a fresh random value per request (CSRF protection); hardcoded here for demo only
    }
    login_url = f"{AUTH_URL}?{urlencode(params)}"
    return login_url
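
# Illustrative example (placeholder values): with an Azure AD-style AUTH_URL
# such as https://login.microsoftonline.com/<TENANT_ID>/oauth2/v2.0/authorize
# (the TENANT_ID variable above suggests Azure AD), oauth_login() returns e.g.
#   https://login.microsoftonline.com/<tenant>/oauth2/v2.0/authorize?client_id=...&response_type=code&redirect_uri=...&response_mode=query&scope=...&state=...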

# Chat handler: run the RAG query and append the (question, answer) turn
def chat(query, history=None):
    if history is None:
        history = []
    if query:
        answer = test_rag(query)
        history.append((query, answer))
    return history, ""  # Clear input after submission

# Function to clear the input textbox (defined but not wired up below)
def clear_input():
    return ""  # Empty string clears the input field

with gr.Blocks() as interface:
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")

    # Sign-In Button
    login_btn = gr.Button("Sign In with HF")
    
    # Open the OAuth login page in a new browser tab. Returning a JS string
    # from a Python callback does nothing; in Gradio 4.x the js= argument
    # runs the snippet client-side instead.
    login_btn.click(None, js=f"() => window.open('{oauth_login()}')")

    # Hidden components initially
    input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...", visible=False)
    submit_btn = gr.Button("Submit", visible=False)
    chat_history = gr.Chatbot(label="Chat History", visible=False)

    # Show components after login
    def show_components():
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

    # Reveal the chat components once the login button is clicked
    # (note: this does not verify that the OAuth code exchange succeeded)
    login_btn.click(show_components, outputs=[input_box, submit_btn, chat_history])
    submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])

interface.launch()