File size: 3,553 Bytes
53a5038
41b8230
 
 
 
 
5cf3079
41b8230
 
 
7bc17d6
41b8230
 
 
05bf013
41b8230
 
 
125b60f
3d03f6e
 
 
41b8230
 
 
 
 
90dbd88
 
 
 
 
 
41b8230
 
 
 
 
 
90dbd88
41b8230
 
 
 
 
 
 
 
 
 
 
 
43c1570
 
41b8230
 
 
 
 
 
 
 
 
 
 
 
53a5038
 
41b8230
7bc17d6
41b8230
 
 
 
 
 
 
 
 
 
 
 
 
 
d8e26f2
 
 
43c1570
 
 
 
 
 
 
 
41b8230
 
43c1570
 
 
 
 
 
 
 
 
 
 
53a5038
41b8230
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
from langchain_community.document_loaders import DirectoryLoader
import torch
import re
import transformers
import spaces

# Initialize embeddings and ChromaDB
# Sentence-transformers checkpoint used to embed the indexed documents; the
# same embedding object is handed to the vector store below.
model_name = "sentence-transformers/all-mpnet-base-v2"
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Single spelling of the on-disk location of the Chroma collection, so the
# store is never addressed through two different path strings.
PERSIST_DIRECTORY = "./companies_db"

# Load every PDF found (recursively) under ./example and index it.
# NOTE(review): this runs on every start; if the persist directory survives
# restarts the same documents are presumably added again -- confirm.
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory=PERSIST_DIRECTORY)

# Reuse the store that was just built instead of opening a second handle on
# the same directory: a second client adds nothing and, depending on the
# Chroma backend, may not see writes that have not been flushed to disk yet.
books_db = vectordb
books_db_client = books_db.as_retriever()

# Initialize the model and tokenizer
# NOTE: this rebinds `model_name`, which previously held the embedding
# checkpoint; from here on it names the causal LM.
model_name = "stabilityai/stablelm-zephyr-3b"

# Optional 4-bit quantization, kept for reference (currently disabled).
# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# Half precision only on the GPU; float32 is kept on CPU, where float16
# kernels are slow or missing for some operations.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# The generation length is set once, on the pipeline below. It is not also
# set on the model config, where the previous 1024 contradicted the 256 that
# the pipeline passes explicitly at generation time.
model_config = transformers.AutoConfig.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    # The dtype has to be chosen here: `pipeline()` only forwards
    # `torch_dtype` / `device_map` to `from_pretrained` when it loads the
    # model itself, not when it is handed an instantiated model.
    torch_dtype=model_dtype,
    device_map=device,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sampling text-generation pipeline around the already-placed model.
# `return_full_text=True` keeps the prompt in the output; `test_rag` searches
# the output for the "Helpful Answer:" marker (presumably emitted by the
# chain's prompt template -- confirm against the LangChain version in use).
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256
)

# Expose the pipeline to LangChain as an LLM.
llm = HuggingFacePipeline(pipeline=query_pipeline)

# Retrieval QA chain: retrieved documents are "stuffed" into a single prompt
# that is then sent to the LLM.
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True
)

# Function to retrieve answer using the RAG system
@spaces.GPU(duration=120)
def test_rag(query):
    """Answer *query* with the retrieval QA chain and return the answer text.

    The chain's raw output is searched for the ``Helpful Answer:`` marker.
    Everything after the first occurrence of the marker, with surrounding
    whitespace stripped, is returned. When the marker is absent a fixed
    fallback message is returned instead.
    """
    raw_output = books_db_client_retriever.run(query)

    # DOTALL lets the capture group span line breaks, so a multi-line answer
    # is kept in full.
    marker_match = re.search(r"Helpful Answer:(.*)", raw_output, re.DOTALL)
    if not marker_match:
        return "No helpful answer found."

    return marker_match.group(1).strip()

# Define the Gradio interface
def chat(query, history=None):
    """Handle one chat submission.

    Args:
        query: Text typed by the user; may be empty or ``None``.
        history: Existing list of ``(question, answer)`` pairs, or ``None``
            when the conversation has not started yet.

    Returns:
        A ``(history, "")`` tuple: the conversation including the new
        exchange (if any), and an empty string used to clear the input box.
    """
    # Work on a copy so the list owned by the caller is never mutated.
    history = list(history) if history else []
    # Skip empty and whitespace-only submissions instead of spending a full
    # retrieval + generation round trip on them.
    if query and query.strip():
        answer = test_rag(query)
        history.append((query, answer))
    return history, ""  # Clear input after submission

# Function to clear input text
def clear_input():
    """Return the empty string used to reset the question input field.

    A plain ``str`` is returned (not a 1-tuple), so the value can be bound
    directly to a single output component.
    """
    return ""

# Gradio interface
with gr.Blocks() as interface:
    # Static header shown above the controls.
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")

    # Components are declared in order: question box, submit button, transcript.
    input_box = gr.Textbox(
        label="Enter your question",
        placeholder="Type your question here...",
    )
    submit_btn = gr.Button("Submit")
    # clear_btn = gr.Button("Clear")
    chat_history = gr.Chatbot(label="Chat History")

    # Clicking the button passes (question, transcript) to chat(); its two
    # return values go back to the transcript and the question box.
    submit_btn.click(
        fn=chat,
        inputs=[input_box, chat_history],
        outputs=[chat_history, input_box],
    )
    # clear_btn.click(clear_input, outputs=input_box)

interface.launch()