delphiclinic committed on
Commit a054510 · verified · 1 Parent(s): facde92

Update app.py

Files changed (1)
  1. app.py +145 -59
app.py CHANGED
@@ -1,63 +1,149 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )
 
 
- if __name__ == "__main__":
-     demo.launch()
 
+ from langchain_community.document_loaders import UnstructuredPowerPointLoader
+ from langchain_community.vectorstores import FAISS
+ from threading import Thread
  import gradio as gr
+ from queue import SimpleQueue, Empty
+ from typing import Any, Dict, List, Union
+ from langchain.callbacks.base import BaseCallbackHandler
+ from langchain.schema import LLMResult
+ # from langchain_community.llms import HuggingFaceTextGenInference
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders import PyPDFLoader
+ from dotenv import load_dotenv, find_dotenv
+ import pickle
+ import os
+ from langchain_community.document_loaders import PyPDFDirectoryLoader
+
+
+ huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+
+ # loader = UnstructuredPowerPointLoader("data/Hypertension.ppt")
+ # documents = loader.load_and_split()
+
+ # loader = PyPDFDirectoryLoader("wolo/")
+ # documents = loader.load_and_split()
+
+ # Define model and vector store
+ embeddings = "BAAI/bge-base-en"
+ encode_kwargs = {'normalize_embeddings': True}
+ model_norm = HuggingFaceBgeEmbeddings(
+     model_name=embeddings,
+     model_kwargs={'device': 'cpu'},
+     encode_kwargs=encode_kwargs
  )
+ # vector_store = FAISS.from_documents(documents, model_norm)
+ # job_done = object()  # signals the processing is done
+
+ # # saving the embeddings locally
+ # vector_store.save_local("wolo_database")
+
+ # loading the saved embeddings
+ vector_store = FAISS.load_local("wolo_database", model_norm, allow_dangerous_deserialization=True)
+ job_done = object()  # signals the processing is done
+
+
+ # Let's set up our streaming
+ class StreamingGradioCallbackHandler(BaseCallbackHandler):
+     """Callback handler - works with LLMs that support streaming."""
+
+     def __init__(self, q: SimpleQueue):
+         self.q = q
+
+     def on_llm_start(
+         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+     ) -> None:
+         """Run when LLM starts running."""
+         while not self.q.empty():
+             try:
+                 self.q.get(block=False)
+             except Empty:
+                 continue
+
+     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+         """Run on new LLM token. Only available when streaming is enabled."""
+         self.q.put(token)
+
+     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         """Run when LLM ends running."""
+         self.q.put(job_done)
+
+     def on_llm_error(
+         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+     ) -> None:
+         """Run when LLM errors."""
+         self.q.put(job_done)
+
+
+ # Initialize the streaming queue and the LLM
+ q = SimpleQueue()
+
+ # from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+ callbacks = [StreamingGradioCallbackHandler(q)]
+ llm = HuggingFaceEndpoint(
+     endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
+     max_new_tokens=512,
+     top_k=10,
+     top_p=0.95,
+     typical_p=0.95,
+     temperature=0.01,
+     repetition_penalty=1.03,
+     callbacks=callbacks,
+     streaming=True,
+     huggingfacehub_api_token=huggingfacehub_api_token
+ )
+
+ # Define prompts and initialize conversation chain
+ prompt = "You are a senior clinician. You only answer questions you have been asked, and always limit your answers to the document content only. Never make up answers. If you do not have the answer, state that the data is not contained in your knowledge base and stop your response."
+ chain = ConversationalRetrievalChain.from_llm(llm=llm, chain_type='stuff',
+                                               retriever=vector_store.as_retriever(
+                                                   search_kwargs={"k": 2}))
+
+ # Set up chat history and streaming for the Gradio display
+ def process_question(question):
+     chat_history = []
+     full_query = f"{prompt} {question}"
+     result = chain({"question": full_query, "chat_history": chat_history})
+     return result["answer"]
+
+
+ def add_text(history, text):
+     history = history + [(text, None)]
+     return history, ""
+
+
+ def streaming_chat(history):
+     user_input = history[-1][0]
+     thread = Thread(target=process_question, args=(user_input,))
+     thread.start()
+     history[-1][1] = ""
+     while True:
+         next_token = q.get(block=True)  # blocks until a token is available
+         if next_token is job_done:
+             break
+         history[-1][1] += next_token
+         yield history
+     thread.join()
+
+
+ # Create the Gradio interface
+ with gr.Blocks(title="Clinical Decision Support System", head="intro.html") as demo:
+     Langchain = gr.Chatbot(label="Response", height=500)
+     Question = gr.Textbox(label="Question")
+     Question.submit(add_text, [Langchain, Question], [Langchain, Question]).then(
+         streaming_chat, Langchain, Langchain
+     )
+
+ demo.queue().launch(share=True, debug=True, favicon_path='thumbnail.jpg')
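
Note: the new app.py only loads a prebuilt FAISS index from the wolo_database folder; the code that created it is left commented out above. A minimal sketch of that one-time indexing step, assuming a local wolo/ directory of PDFs and the same BGE embedding settings (the directory path and output folder names are taken from the commented-out lines, not confirmed elsewhere):

    from langchain_community.document_loaders import PyPDFDirectoryLoader
    from langchain_community.embeddings import HuggingFaceBgeEmbeddings
    from langchain_community.vectorstores import FAISS

    # Same embedding configuration as app.py
    model_norm = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )

    # Load and chunk every PDF under wolo/ (assumed source folder), then embed and index the chunks
    documents = PyPDFDirectoryLoader("wolo/").load_and_split()
    vector_store = FAISS.from_documents(documents, model_norm)

    # Write the index to the folder that app.py later reads with FAISS.load_local
    vector_store.save_local("wolo_database")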