File size: 3,865 Bytes
bd937b0
 
 
 
c3ef1a1
bd937b0
6efd821
 
 
bd937b0
53a5038
4a30cca
41b8230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05bf013
 
41b8230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0e01c8
 
 
 
 
 
41b8230
 
 
 
 
 
f9a5c3b
41b8230
 
 
 
 
 
 
 
 
 
 
 
d8e26f2
 
41b8230
 
 
 
 
d8e26f2
41b8230
 
 
 
 
 
 
53a5038
 
41b8230
c42c5d9
41b8230
 
 
 
 
 
 
 
 
 
 
 
 
 
d8e26f2
 
 
41b8230
 
 
 
 
 
 
 
 
 
 
53a5038
41b8230
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# import subprocess
import os
# # Run setup.sh script before starting the app
# subprocess.run(["/bin/bash", "setup.sh"], check=True)
# Best-effort runtime setup: install/upgrade dependencies before the heavy
# imports below. A failing command is reported but deliberately does not abort
# startup (e.g. apt-get may be unavailable or lack privileges in some hosts).
_SETUP_COMMANDS = (
    'pip install --upgrade pip',
    'apt-get update && apt-get install -y libmagic1',
    'pip install -U langchain-community',
    'pip install --upgrade accelerate',
    'pip install -i https://pypi.org/simple/ bitsandbytes --upgrade',
)
for _setup_cmd in _SETUP_COMMANDS:
    # os.system returns the raw wait status; any non-zero value means failure.
    _setup_status = os.system(_setup_cmd)
    if _setup_status != 0:
        print(f"WARNING: setup command failed (status {_setup_status}): {_setup_cmd}")

import gradio as gr
import spaces
# import fitz  # PyMuPDF for extracting text from PDFs
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
import torch
import re
import transformers
from torch import bfloat16
from langchain_community.document_loaders import DirectoryLoader

# Initialize embeddings and ChromaDB
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Load every PDF under ./example (recursively) and split it into overlapping
# 1000-character chunks for retrieval.
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)

# Chroma.from_documents adds its chunks (with newly generated ids) to any
# collection already persisted at this path, so rebuilding on top of an old
# directory would duplicate every chunk on each restart. The store is derived
# entirely from ./example, so clear it and rebuild from scratch.
import shutil

chroma_persist_directory = "./example_chroma_companies"
shutil.rmtree(chroma_persist_directory, ignore_errors=True)
vectordb = Chroma.from_documents(
    documents=all_splits,
    embedding=embeddings,
    persist_directory=chroma_persist_directory,
)
# Reuse the store just built rather than opening a second client on the same path.
books_db = vectordb

books_db_client = books_db.as_retriever()

# Initialize the model and tokenizer
# NOTE: this rebinds `model_name`, which previously held the embedding model's
# name; from here on it refers to the generation model.
model_name = "stabilityai/stablelm-zephyr-3b"

# 4-bit quantization config, currently disabled together with the
# `quantization_config` argument further down.
# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# NOTE(review): max_new_tokens=1024 is set on the config here, while the
# pipeline below passes max_new_tokens=256; presumably the pipeline value is
# the one used at generation time — confirm.
model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
# trust_remote_code=True lets code shipped in the model repository run locally.
# No torch_dtype is passed here, so the weights load in the library default dtype.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    device_map=device,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sampling text-generation pipeline around the already-loaded model.
# return_full_text=True: the output contains the prompt followed by the
# generation; test_rag extracts the part after "Helpful Answer:" (presumably
# the last line of the chain's default prompt — confirm).
# NOTE(review): since `model` is already instantiated, torch_dtype and
# device_map passed here probably do not reload or cast it — verify against
# the installed transformers version before relying on float16.
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    torch_dtype=torch.float16,
    device_map=device,
    do_sample=True,  # Enable sampling
    temperature=0.7,  # Keep if sampling is used
    top_p=0.9,
    top_k=50,
    max_new_tokens=256
)


# Expose the transformers pipeline to LangChain as an LLM.
llm = HuggingFacePipeline(pipeline=query_pipeline)

# Retrieval QA chain: fetches documents with books_db_client and "stuffs" all
# of them into a single prompt for `llm`; verbose=True logs the chain's steps.
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True
)

# Function to retrieve answer using the RAG system
@spaces.GPU()
def test_rag(query):
    """Answer *query* with the retrieval-QA chain.

    The chain's raw output is split at the first "Helpful Answer:" marker and
    everything after it is returned with surrounding whitespace stripped.
    When the marker is absent, the fixed string "No helpful answer found."
    is returned instead.
    """
    raw_output = books_db_client_retriever.run(query)

    # partition() splits at the first occurrence, like the original regex search.
    _, marker, tail = raw_output.partition("Helpful Answer:")
    if not marker:
        return "No helpful answer found."
    return tail.strip()

# Define the Gradio interface
def chat(query, history=None):
    """Run one chat turn and return the updated history for the UI.

    Args:
        query: Text from the question textbox.
        history: List of ``(question, answer)`` pairs kept in ``gr.State``;
            ``None`` when the state has no value yet (first turn).

    Returns:
        The updated history twice: once for the ``Chatbot`` display and once
        to store back into ``gr.State``.

    Blank or whitespace-only queries are ignored. The history comes back
    unchanged and the RAG chain (with its GPU allocation) is not invoked.
    """
    # Work on a copy so the caller's list is never mutated in place.
    history = [] if history is None else list(history)
    if not query or not query.strip():
        return history, history
    answer = test_rag(query)
    history.append((query, answer))
    return history, history

# Gradio interface.
# live=False: in live mode Gradio reruns `fn` whenever an input changes, which
# here would start a full RAG generation and append a half-typed question to
# the chat history on every keystroke. Users submit the finished question.
interface = gr.Interface(
    fn=chat,
    inputs=[gr.Textbox(label="Enter your question"), gr.State()],
    outputs=[gr.Chatbot(label="Chat History"), gr.State()],
    live=False,
)

interface.launch()