Akjava committed on
Commit fbeaa20 · verified · 1 Parent(s): e761993

Update app.py

Files changed (1):
  1. app.py +79 -4

app.py CHANGED
@@ -14,11 +14,75 @@ from llama_cpp_agent.chat_history import BasicChatHistory
 from llama_cpp_agent.chat_history.messages import Roles
 import gradio as gr
 from huggingface_hub import hf_hub_download
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Optional
 from logger import logging
 from exception import CustomExceptionHandling
 
+from smolagents.gradio_ui import GradioUI
+from smolagents import (
+    CodeAgent,
+    GoogleSearchTool,
+    Model,
+    Tool,
+    LiteLLMModel,
+    ToolCallingAgent,
+    ChatMessage, tool, MessageRole
+)
+
+cache_file = "docs_processed.joblib"
+if os.path.exists(cache_file):
+    docs_processed = joblib.load(cache_file)
+    print("Loaded docs_processed from cache.")
+else:
+    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
+    source_docs = [
+        Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
+    ]
+
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=400,
+        chunk_overlap=20,
+        add_start_index=True,
+        strip_whitespace=True,
+        separators=["\n\n", "\n", ".", " ", ""],
+    )
+    docs_processed = text_splitter.split_documents(source_docs)
+    joblib.dump(docs_processed, cache_file)
+    print("Created and saved docs_processed to cache.")
+
+class RetrieverTool(Tool):
+    name = "retriever"
+    description = "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self, docs, **kwargs):
+        super().__init__(**kwargs)
+
+        self.retriever = BM25Retriever.from_documents(
+            docs,
+            k=7,
+        )
+
+    def forward(self, query: str) -> str:
+        assert isinstance(query, str), "Your search query must be a string"
 
+        docs = self.retriever.invoke(
+            query,
+        )
+        return "\nRetrieved documents:\n" + "".join(
+            [
+                f"\n\n===== Document {str(i)} =====\n" + str(doc.page_content)
+                for i, doc in enumerate(docs)
+            ]
+        )
+
+retriever_tool = RetrieverTool(docs_processed)
 # Download gguf model files
 huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
@@ -88,10 +152,21 @@ def respond(
         llm_model = model
     provider = LlamaCppPythonProvider(llm)
 
+    text = retriever_tool(query=f"{message}")
+
+    retriever_system = """
+    You are an AI assistant that answers questions based on documents provided by the user. Wait for the user to send a document. Once you receive the document, carefully read its contents and then answer the following question:
+
+    Question: %s
+
+    [Wait for user's message containing the document]
+    """ % message
+
+
     # Create the agent
     agent = LlamaCppAgent(
         provider,
+        system_prompt=f"{retriever_system}",
         predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
         debug_output=True,
     )
@@ -116,7 +191,7 @@ def respond(
 
     # Get the response stream
     stream = agent.get_chat_response(
-        message,
+        text,
         llm_sampling_settings=settings,
         chat_history=messages,
         returns_streaming_generator=True,
@@ -141,7 +216,7 @@ def respond(
 # Create a chat interface
 demo = gr.ChatInterface(
     respond,
-    examples=[["What is the capital of France?"], ["Tell me something about artificial intelligence."], ["What is gravity?"]],
+    examples=[["What is the Transformers library?"], ["Tell me about Hugging Face."], ["How do I upload a dataset?"]],
     additional_inputs_accordion=gr.Accordion(
         label="⚙️ Parameters", open=False, render=False
     ),
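
Note: the retrieval code added in the first hunk references joblib, datasets, Document, RecursiveCharacterTextSplitter, and BM25Retriever, none of which are imported in the lines this diff touches, so they would need to be imported near the top of app.py (not shown here). As a sanity check, the new BM25 retrieval path can be exercised on its own; the following is a minimal sketch, assuming the usual LangChain community import location and a hypothetical query string:

import joblib
from langchain_community.retrievers import BM25Retriever

# Load the chunk cache that app.py writes on first run.
docs_processed = joblib.load("docs_processed.joblib")

# Rebuild the retriever with k=7, the value RetrieverTool uses above.
retriever = BM25Retriever.from_documents(docs_processed, k=7)

# Hypothetical query, phrased affirmatively as the tool's input description advises.
for i, doc in enumerate(retriever.invoke("uploading a dataset to the Hub")):
    print(f"===== Document {i} =====\n{doc.page_content[:200]}")

Worth noting: BM25 is lexical (keyword) ranking rather than embedding-based semantic search, so the tool runs without an embedding model, which keeps the Space light.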