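# Hugging Face Space: a Gradio RAG demo that indexes local PDFs into ChromaDB
# and answers questions with stabilityai/stablelm-zephyr-3b through LangChain.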
# import subprocess
import os
# # Run setup.sh script before starting the app
# subprocess.run(["/bin/bash", "setup.sh"], check=True)
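# Install runtime dependencies at startup; libmagic1 is needed by the
# unstructured-based PDF loading used further down.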
os.system('pip install --upgrade pip')
os.system('apt-get update && apt-get install -y libmagic1')
os.system('pip install -U langchain-community')
os.system('pip install --upgrade accelerate')
os.system('pip install -i https://pypi.org/simple/ bitsandbytes --upgrade')
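# bitsandbytes is only needed if the 4-bit quantization config below is re-enabled.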
import gradio as gr
import spaces
# import fitz # PyMuPDF for extracting text from PDFs
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer
import torch
import re
import transformers
from langchain_community.document_loaders import DirectoryLoader
# Initialize embeddings and ChromaDB
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cuda"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
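# Recursively load every PDF under ./example; DirectoryLoader parses files
# with the unstructured package by default (hence the libmagic install above).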
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
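# Split documents into 1000-character chunks with 200 characters of overlap
# so context is preserved across chunk boundaries.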
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)
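# Build the Chroma index on disk, then reopen it and expose it as a retriever.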
vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="example_chroma_companies")
books_db = Chroma(persist_directory="./example_chroma_companies", embedding_function=embeddings)
books_db_client = books_db.as_retriever()
# Initialize the model and tokenizer
model_name = "stabilityai/stablelm-zephyr-3b"
# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )
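# Note: max_new_tokens on the config is informational here; the pipeline's
# explicit max_new_tokens=256 below is what actually caps generation.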
model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
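# Text-generation pipeline; do_sample with temperature/top_p/top_k trades
# determinism for more varied answers.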
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    torch_dtype=torch.float16,
    device_map=device,
    do_sample=True,  # enable sampling so the settings below take effect
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,
)
llm = HuggingFacePipeline(pipeline=query_pipeline)
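# "stuff" chain: all retrieved chunks are concatenated into a single prompt.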
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True,
)
# Function to retrieve answer using the RAG system
@spaces.GPU()
def test_rag(query):
    books_retriever = books_db_client_retriever.run(query)
    # return_full_text=True echoes the prompt, so the output includes the
    # chain's default prompt ending in "Helpful Answer:"; extract what follows
    corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
    if corrected_text_match:
        corrected_text_books = corrected_text_match.group(1).strip()
    else:
        corrected_text_books = "No helpful answer found."
    return corrected_text_books
# Define the Gradio interface
def chat(query, history=None):
    if history is None:
        history = []
    answer = test_rag(query)
    history.append((query, answer))
    return history, history
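# chat returns the updated history twice: once to render in the Chatbot
# component and once to carry forward in gr.State.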
# Gradio interface
interface = gr.Interface(
    fn=chat,
    inputs=[gr.Textbox(label="Enter your question"), gr.State()],
    outputs=[gr.Chatbot(label="Chat History"), gr.State()],
    # live=True is omitted: it would re-run the LLM on every keystroke
)
interface.launch()