sanjeevbora committed on
Commit 41b8230 · verified · 1 Parent(s): c669e82

update app.py

Files changed (1)
  1. app.py +98 -59
app.py CHANGED
@@ -1,63 +1,102 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )


- if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ # import fitz  # PyMuPDF for extracting text from PDFs
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import Chroma
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.docstore.document import Document
+ from langchain.llms import HuggingFacePipeline
+ from langchain.chains import RetrievalQA
+ from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
+ import torch
+ import re
+ import transformers
+ from torch import bfloat16
+ from langchain_community.document_loaders import DirectoryLoader
+
+ # Initialize embeddings and ChromaDB
+ model_name = "sentence-transformers/all-mpnet-base-v2"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_kwargs = {"device": device}
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
+
+ loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
+ docs = loader.load()
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+ all_splits = text_splitter.split_documents(docs)
+ vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="example_chroma_companies")
+ books_db = Chroma(persist_directory="./example_chroma_companies", embedding_function=embeddings)
+
+ books_db_client = books_db.as_retriever()
+
+ # Initialize the model and tokenizer
+ model_name = "stabilityai/stablelm-zephyr-3b"
+
+ bnb_config = transformers.BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type='nf4',
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+     config=model_config,
+     quantization_config=bnb_config,
+     device_map=device,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ query_pipeline = transformers.pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     return_full_text=True,
+     torch_dtype=torch.float16,
+     device_map=device,
+     temperature=0.7,
+     top_p=0.9,
+     top_k=50,
+     max_new_tokens=256
+ )
+
+ llm = HuggingFacePipeline(pipeline=query_pipeline)
+
+ books_db_client_retriever = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type="stuff",
+     retriever=books_db_client,
+     verbose=True
  )

+ # Function to retrieve answer using the RAG system
+ def test_rag(query):
+     books_retriever = books_db_client_retriever.run(query)
+
+     # Extract the relevant answer using regex
+     corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
+
+     if corrected_text_match:
+         corrected_text_books = corrected_text_match.group(1).strip()
+     else:
+         corrected_text_books = "No helpful answer found."
+
+     return corrected_text_books
+
+ # Define the Gradio interface
+ def chat(query, history=[]):
+     answer = test_rag(query)
+     history.append((query, answer))
+     return history, history
+
+ # Gradio interface
+ interface = gr.Interface(
+     fn=chat,
+     inputs=[gr.Textbox(label="Enter your question"), gr.State()],
+     outputs=[gr.Chatbot(label="Chat History"), gr.State()],
+     live=True
+ )

+ interface.launch()
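
Not part of this commit, but a minimal smoke-test sketch for the new retrieval path: a few lines one might paste into the updated app.py just above `interface.launch()` to confirm the RetrievalQA chain answers before the Gradio UI starts. The query string is made up; it assumes `test_rag` and `chat` keep the signatures shown in the diff and that the `./example` directory contains at least one PDF.

```python
# Hypothetical smoke test (not included in the commit): place above interface.launch().
sample_query = "What do the example PDFs say about the company?"  # made-up query

answer = test_rag(sample_query)            # text extracted after "Helpful Answer:"
print("RAG answer:", answer)

history, _ = chat(sample_query, [])        # same answer wrapped as Gradio chat history
print("Chat history entry:", history[-1])
```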