sanjeevbora committed
Commit 43c1570 · verified · 1 Parent(s): 4d35139

Updated UI

Files changed (1)
  1. app.py +29 -44
app.py CHANGED
@@ -1,61 +1,38 @@
-# import subprocess
-import os
-# # Run setup.sh script before starting the app
-# subprocess.run(["/bin/bash", "setup.sh"], check=True)
-# os.system('pip install --upgrade pip')
-# os.system('apt-get update && apt-get install -y libmagic1')
-# os.system('pip install -U langchain-community')
-# os.system('pip install --upgrade accelerate')
-# os.system('pip install -i https://pypi.org/simple/ bitsandbytes --upgrade')
-
 import gradio as gr
-import spaces
-# import fitz  # PyMuPDF for extracting text from PDFs
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import Chroma
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.docstore.document import Document
 from langchain.llms import HuggingFacePipeline
 from langchain.chains import RetrievalQA
 from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
 import torch
 import re
 import transformers
-from torch import bfloat16
-from langchain_community.document_loaders import DirectoryLoader
 
 # Initialize embeddings and ChromaDB
 model_name = "sentence-transformers/all-mpnet-base-v2"
 device = "cuda" if torch.cuda.is_available() else "cpu"
-# device = "cuda"
 model_kwargs = {"device": device}
 embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
 
-loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
-docs = loader.load()
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-all_splits = text_splitter.split_documents(docs)
-vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="example_chroma_companies")
-books_db = Chroma(persist_directory="./example_chroma_companies", embedding_function=embeddings)
-
+books_db = Chroma(persist_directory="./chroma_companies", embedding_function=embeddings)
 books_db_client = books_db.as_retriever()
 
 # Initialize the model and tokenizer
 model_name = "stabilityai/stablelm-zephyr-3b"
 
-# bnb_config = transformers.BitsAndBytesConfig(
-#     load_in_4bit=True,
-#     bnb_4bit_quant_type='nf4',
-#     bnb_4bit_use_double_quant=True,
-#     bnb_4bit_compute_dtype=torch.bfloat16
-# )
+bnb_config = transformers.BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type='nf4',
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
 
 model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
 model = transformers.AutoModelForCausalLM.from_pretrained(
     model_name,
     trust_remote_code=True,
     config=model_config,
-    # quantization_config=bnb_config,
+    quantization_config=bnb_config,
     device_map=device,
 )
 
@@ -68,14 +45,13 @@ query_pipeline = transformers.pipeline(
     return_full_text=True,
     torch_dtype=torch.float16,
     device_map=device,
-    do_sample=True,  # Enable sampling
-    temperature=0.7,  # Keep if sampling is used
+    do_sample=True,
+    temperature=0.7,
     top_p=0.9,
     top_k=50,
     max_new_tokens=256
 )
 
-
 llm = HuggingFacePipeline(pipeline=query_pipeline)
 
 books_db_client_retriever = RetrievalQA.from_chain_type(
@@ -86,7 +62,6 @@ books_db_client_retriever = RetrievalQA.from_chain_type(
 )
 
 # Function to retrieve answer using the RAG system
-@spaces.GPU(duration=120)
 def test_rag(query):
     books_retriever = books_db_client_retriever.run(query)
 
@@ -104,16 +79,26 @@ def test_rag(query):
 def chat(query, history=None):
     if history is None:
         history = []
-    answer = test_rag(query)
-    history.append((query, answer))
-    return history, history
+    if query:
+        answer = test_rag(query)
+        history.append((query, answer))
+    return history, ""  # Clear input after submission
+
+# Function to clear input text
+def clear_input():
+    return "",  # Return empty string to clear input field
 
 # Gradio interface
-interface = gr.Interface(
-    fn=chat,
-    inputs=[gr.Textbox(label="Enter your question"), gr.State()],
-    outputs=[gr.Chatbot(label="Chat History"), gr.State()],
-    live=True
-)
+with gr.Blocks() as interface:
+    gr.Markdown("## RAG Chatbot")
+    gr.Markdown("Ask a question and get answers based on retrieved documents.")
+
+    input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
+    submit_btn = gr.Button("Submit")
+    # clear_btn = gr.Button("Clear")
+    chat_history = gr.Chatbot(label="Chat History")
+
+    submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])
+    # clear_btn.click(clear_input, outputs=input_box)
 
 interface.launch()
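
Reviewer note: with this change, app.py no longer builds the vector store at startup. The removed lines (DirectoryLoader, RecursiveCharacterTextSplitter, Chroma.from_documents) did the PDF ingestion inline and persisted to example_chroma_companies, while the updated code opens a prebuilt index at ./chroma_companies, so that directory must exist on disk before launch. Below is a minimal one-off ingestion sketch reusing the loader and chunking settings from the removed code; the script name, the ./example source folder, and the output path are assumptions for illustration, not part of this commit.

# build_index.py -- hypothetical one-off ingestion script, not part of this commit
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import DirectoryLoader

# Must match the embedding model app.py uses at query time,
# since Chroma stores only the vectors, not the model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Same loader and chunking parameters the removed code used
docs = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True).load()
splits = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)

# Persist to the directory the updated app.py now reads from
Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory="./chroma_companies")

One smaller observation: clear_input ends with 'return "",' and the trailing comma makes it return the one-element tuple ("",) rather than a bare string. That is harmless while the Clear button stays commented out, but the comma should be dropped if the button is re-enabled.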