Akjava committed
Commit 58e6047 · verified · 1 Parent(s): 9ce7c71

Update app.py

Files changed (1)
  1. app.py +242 -176
app.py CHANGED
@@ -1,217 +1,283 @@
  # Importing required libraries
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.retrievers import BM25Retriever
+
  import warnings
  warnings.filterwarnings("ignore")
-
+ import datasets
  import os
  import json
  import subprocess
  import sys
+ import joblib
  from llama_cpp import Llama
  from llama_cpp_agent import LlamaCppAgent
  from llama_cpp_agent import MessagesFormatterType
  from llama_cpp_agent.providers import LlamaCppPythonProvider
  from llama_cpp_agent.chat_history import BasicChatHistory
  from llama_cpp_agent.chat_history.messages import Roles
+ from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings
+
  import gradio as gr
  from huggingface_hub import hf_hub_download
- from typing import List, Tuple
+ from typing import List, Tuple,Dict,Optional
  from logger import logging
  from exception import CustomExceptionHandling

+ from smolagents.gradio_ui import GradioUI
+ from smolagents import (
+     CodeAgent,
+     GoogleSearchTool,
+     Model,
+     Tool,
+     LiteLLMModel,
+     ToolCallingAgent,
+     ChatMessage,tool,MessageRole
+ )
+
+ cache_file = "docs_processed.joblib"
+ if os.path.exists(cache_file):
+     docs_processed = joblib.load(cache_file)
+     print("Loaded docs_processed from cache.")
+ else:
+     knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
+     source_docs = [
+         Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
+     ]
+
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=400,
+         chunk_overlap=20,
+         add_start_index=True,
+         strip_whitespace=True,
+         separators=["\n\n", "\n", ".", " ", ""],
+     )
+     docs_processed = text_splitter.split_documents(source_docs)
+     joblib.dump(docs_processed, cache_file)
+     print("Created and saved docs_processed to cache.")
+
+ class RetrieverTool(Tool):
+     name = "retriever"
+     description = "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
+     inputs = {
+         "query": {
+             "type": "string",
+             "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+         }
+     }
+     output_type = "string"
+
+     def __init__(self, docs, **kwargs):
+         super().__init__(**kwargs)
+
+         self.retriever = BM25Retriever.from_documents(
+             docs,
+             k=7,
+         )
+
+     def forward(self, query: str) -> str:
+         assert isinstance(query, str), "Your search query must be a string"
+
+         docs = self.retriever.invoke(
+             query,
+         )
+         return "\nRetrieved documents:\n" + "".join(
+             [
+                 f"\n\n===== Document {str(i)} =====\n" + str(doc.page_content)
+                 for i, doc in enumerate(docs)
+             ]
+         )
+
+

  # Download gguf model files
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

- hf_hub_download(
-     repo_id="bartowski/google_gemma-3-1b-it-GGUF",
-     filename="google_gemma-3-1b-it-Q6_K.gguf",
-     local_dir="./models",
- )
+ os.makedirs("models",exist_ok=True)
+
+ logging.info("start download")
  hf_hub_download(
      repo_id="bartowski/google_gemma-3-1b-it-GGUF",
      filename="google_gemma-3-1b-it-Q5_K_M.gguf",
      local_dir="./models",
  )

- # Set the title and description
- title = "Gemma Llama.cpp"
- description = """Gemma 3 is a family of lightweight, multimodal open models that offers advanced capabilities like large context windows and multilingual support, enabling diverse applications on various devices."""
-
-
- llm = None
- llm_model = None
-
- def respond(
-     message: str,
-     history: List[Tuple[str, str]],
-     model: str,
-     system_message: str,
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
-     top_k: int,
-     repeat_penalty: float,
- ):
-     """
-     Respond to a message using the Gemma3 model via Llama.cpp.
-
-     Args:
-         - message (str): The message to respond to.
-         - history (List[Tuple[str, str]]): The chat history.
-         - model (str): The model to use.
-         - system_message (str): The system message to use.
-         - max_tokens (int): The maximum number of tokens to generate.
-         - temperature (float): The temperature of the model.
-         - top_p (float): The top-p of the model.
-         - top_k (int): The top-k of the model.
-         - repeat_penalty (float): The repetition penalty of the model.
-
-     Returns:
-         str: The response to the message.
-     """
-     try:
-         # Load the global variables
-         global llm
-         global llm_model
-
-         # Load the model
-         if llm is None or llm_model != model:
-             llm = Llama(
-                 model_path=f"models/{model}",
+ retriever_tool = RetrieverTool(docs_processed)
+
+ # based https://github.com/huggingface/smolagents/pull/450
+ # almost overwrite with https://huggingface.co/spaces/sitammeur/Gemma-llamacpp
+ class LlamaCppModel(Model):
+     def __init__(
+         self,
+         model_path: Optional[str] = None,
+         repo_id: Optional[str] = None,
+         filename: Optional[str] = None,
+         n_gpu_layers: int = 0,
+         n_ctx: int = 8192,
+         max_tokens: int = 1024,
+         verbose:bool = False,
+         **kwargs,
+     ):
+         """
+         Initializes the LlamaCppModel.
+
+         Parameters:
+             model_path (str, optional): Path to the local model file.
+             repo_id (str, optional): Hugging Face repository ID if loading from Hugging Face.
+             filename (str, optional): Specific filename to load from the repository.
+             n_gpu_layers (int, default=0): Number of GPU layers to use.
+             n_ctx (int, default=8192): Context size for the model.
+             **kwargs: Additional keyword arguments.
+         Raises:
+             ValueError: If neither model_path nor repo_id+filename are provided.
+         """
+         from llama_cpp import Llama
+         print("init2")
+
+         super().__init__(**kwargs)
+         self.flatten_messages_as_text=True
+         self.max_tokens = max_tokens
+
+         if model_path:
+             self.llm = Llama(
+                 model_path=model_path,
                  flash_attn=False,
                  n_gpu_layers=0,
                  n_batch=8,
-                 n_ctx=2048,
+                 n_ctx=n_ctx,
                  n_threads=2,
-                 n_threads_batch=2,
+                 n_threads_batch=2,verbose=False
+             )
+
+         elif repo_id and filename:
+             self.llm = Llama.from_pretrained(
+                 repo_id=repo_id,
+                 filename=filename,
+                 n_gpu_layers=n_gpu_layers,
+                 n_ctx=n_ctx,
+                 max_tokens=max_tokens,
+                 verbose=verbose,
+                 **kwargs
+             )
+         else:
+             raise ValueError("Must provide either model_path or repo_id+filename")
+
+     def __call__(
+         self,
+         messages: List[Dict[str, str]],
+         stop_sequences: Optional[List[str]] = None,
+         grammar: Optional[str] = None,
+         tools_to_call_from: Optional[List[Tool]] = None,
+         **kwargs,
+     ) -> ChatMessage:
+
+         """
+         Generates a response from the llama.cpp model and integrates tool usage *only if tools are provided*.
+         """
+
+         from llama_cpp import LlamaGrammar
+         try:
+             completion_kwargs = self._prepare_completion_kwargs(
+                 messages=messages,
+                 stop_sequences=stop_sequences,
+                 grammar=grammar,
+                 tools_to_call_from=tools_to_call_from,
+                 #flatten_messages_as_text=True,
+                 **kwargs
+             )
+
+             if not tools_to_call_from:
+                 completion_kwargs.pop("tools", None)
+                 completion_kwargs.pop("tool_choice", None)
+
+             filtered_kwargs = {
+                 k: v for k, v in completion_kwargs.items()
+                 if k not in ["messages", "stop", "grammar", "max_tokens", "tools_to_call_from"]
+             }
+             max_tokens = (
+                 kwargs.get("max_tokens")
+                 or self.max_tokens
+                 or 1024
              )
-             llm_model = model
-         provider = LlamaCppPythonProvider(llm)
-
-         # Create the agent
-         agent = LlamaCppAgent(
-             provider,
-             system_prompt=f"{system_message}",
-             predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
-             debug_output=True,
-         )

-         # Set the settings like temperature, top-k, top-p, max tokens, etc.
-         settings = provider.get_provider_default_settings()
-         settings.temperature = temperature
-         settings.top_k = top_k
-         settings.top_p = top_p
-         settings.max_tokens = max_tokens
-         settings.repeat_penalty = repeat_penalty
-         settings.stream = True
-
-         messages = BasicChatHistory()
-
-         # Add the chat history
-         for msn in history:
-             user = {"role": Roles.user, "content": msn[0]}
-             assistant = {"role": Roles.assistant, "content": msn[1]}
-             messages.add_message(user)
-             messages.add_message(assistant)
-
-         # Get the response stream
-         stream = agent.get_chat_response(
+             provider = LlamaCppPythonProvider(self.llm)
+             system_message= completion_kwargs["messages"][0]["content"]
+             message= completion_kwargs["messages"].pop()["content"]
+
+             # Create the agent
+             agent = LlamaCppAgent(
+                 provider,
+                 system_prompt=f"{system_message}",
+                 predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
+                 debug_output=True,
+             )
+             temperature = 0.7
+             top_k=40
+             top_p=0.95
+             max_tokens=1024
+             repeat_penalty=1.1
+             settings = provider.get_provider_default_settings()
+             settings.temperature = temperature
+             settings.top_k = top_k
+             settings.top_p = top_p
+             settings.max_tokens = max_tokens
+             settings.repeat_penalty = repeat_penalty
+             settings.stream = False
+
+             print(len(completion_kwargs["messages"]))
+             messages = BasicChatHistory()
+             for from_message in completion_kwargs["messages"]:
+                 if from_message["role"] is MessageRole.USER:
+                     history_message = {"role": MessageRole.USER, "content": from_message["content"]}
+                 elif from_message["role"] is MessageRole.SYSTEM:
+                     history_message = {"role": MessageRole.SYSTEM, "content": from_message["content"]}
+                 else:
+                     history_message = {"role": MessageRole.ASSISTANT, "content": from_message["content"]}
+                 messages.add_message(from_message)
+             print("<history>")
+             stream = agent.get_chat_response(
              message,
              llm_sampling_settings=settings,
              chat_history=messages,
-             returns_streaming_generator=True,
-             print_output=False,
-         )
+             returns_streaming_generator=False,
+             print_output=True,
+
+             )
+
+             content = stream
+             message = ChatMessage(role=MessageRole.ASSISTANT, content=content)

-         # Log the success
-         logging.info("Response stream generated successfully")
-
-         # Generate the response
-         outputs = ""
-         for output in stream:
-             outputs += output
-             yield outputs
-
-     # Handle exceptions that may occur during the process
-     except Exception as e:
-         # Custom exception handling
-         raise CustomExceptionHandling(e, sys) from e
-
-
- # Create a chat interface
- demo = gr.ChatInterface(
-     respond,
-     examples=[["What is the capital of France?"], ["Tell me something about artificial intelligence."], ["What is gravity?"]],
-     additional_inputs_accordion=gr.Accordion(
-         label="⚙️ Parameters", open=False, render=False
-     ),
-     additional_inputs=[
-         gr.Dropdown(
-             choices=[
-                 "google_gemma-3-1b-it-Q6_K.gguf",
-                 "google_gemma-3-1b-it-Q5_K_M.gguf",
-             ],
-             value="google_gemma-3-1b-it-Q5_K_M.gguf",
-             label="Model",
-             info="Select the AI model to use for chat",
-         ),
-         gr.Textbox(
-             value="You are a helpful assistant.",
-             label="System Prompt",
-             info="Define the AI assistant's personality and behavior",
-             lines=2,
-         ),
-         gr.Slider(
-             minimum=512,
-             maximum=2048,
-             value=1024,
-             step=1,
-             label="Max Tokens",
-             info="Maximum length of response (higher = longer replies)",
-         ),
-         gr.Slider(
-             minimum=0.1,
-             maximum=2.0,
-             value=0.7,
-             step=0.1,
-             label="Temperature",
-             info="Creativity level (higher = more creative, lower = more focused)",
-         ),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p",
-             info="Nucleus sampling threshold",
-         ),
-         gr.Slider(
-             minimum=1,
-             maximum=100,
-             value=40,
-             step=1,
-             label="Top-k",
-             info="Limit vocabulary choices to top K tokens",
-         ),
-         gr.Slider(
-             minimum=1.0,
-             maximum=2.0,
-             value=1.1,
-             step=0.1,
-             label="Repetition Penalty",
-             info="Penalize repeated words (higher = less repetition)",
-         ),
-     ],
-     theme="Ocean",
-     submit_btn="Send",
-     stop_btn="Stop",
-     title=title,
-     description=description,
-     chatbot=gr.Chatbot(scale=1, show_copy_button=True),
-     flagging_mode="never",
- )
+             if tools_to_call_from is not None:
+                 return super.parse_tool_args_if_needed(message)
+             return message
+         except Exception as e:
+             logging.error(f"Model error: {e}")
+             return ChatMessage(role="assistant", content=f"Error: {str(e)}")
+
+
+ model = LlamaCppModel(
+     model_path = "models/google_gemma-3-1b-it-Q5_K_M.gguf",
+     n_ctx=8192,verbose=False
+ )
+
+ import yaml
+ with open("test.yaml", "r") as f:
+     prompt = f.read()
+
+ description="""
+ *CPU Rag Example with LlamaCpp*
+ Take a few minute.
+
+ Reference
+ - [pull-450](https://github.com/huggingface/smolagents/pull/450)
+ - [Gemma-llamacpp](https://huggingface.co/spaces/sitammeur/Gemma-llamacpp)
+
+ """
+ #Tool not support
+ agent = CodeAgent(prompt_templates =yaml.safe_load(prompt),model=model, tools=[retriever_tool],max_steps=2,verbosity_level=2,name="AGENT",description=description)

+ demo = GradioUI(agent)

- # Launch the chat interface
  if __name__ == "__main__":
-     demo.launch(debug=False)
+     demo.launch()