Akjava's picture
Update app.py
ce86b70 verified
# Note: this file is almost entirely example code, so it is probably not affected by licensing concerns.
# Importing required libraries
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
import warnings
warnings.filterwarnings("ignore")
import datasets
import os
import json
import subprocess
import sys
import joblib
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple,Dict,Optional
from logger import logging
from exception import CustomExceptionHandling
from smolagents.gradio_ui import GradioUI
from smolagents import (
CodeAgent,
GoogleSearchTool,
Model,
Tool,
LiteLLMModel,
ToolCallingAgent,
ChatMessage,tool,MessageRole
)
# The split documentation chunks are cached on disk so the dataset only has
# to be downloaded and split on the first run.
cache_file = "docs_processed.joblib"
if not os.path.exists(cache_file):
    # Cache miss: load the documentation dataset and wrap each row in a
    # Document whose "source" metadata is the second "/"-separated component
    # of the row's source path.
    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
    source_docs = [
        Document(
            page_content=row["text"],
            metadata={"source": row["source"].split("/")[1]},
        )
        for row in knowledge_base
    ]
    # Split into chunks of at most 400 characters with a 20-character overlap.
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", " ", ""],
        chunk_size=400,
        chunk_overlap=20,
        add_start_index=True,
        strip_whitespace=True,
    )
    docs_processed = text_splitter.split_documents(source_docs)
    joblib.dump(docs_processed, cache_file)
    print("Created and saved docs_processed to cache.")
else:
    # Cache hit: reuse the chunks written by a previous run of this script.
    # NOTE: joblib files are pickle-based, so only load caches this app wrote.
    docs_processed = joblib.load(cache_file)
    print("Loaded docs_processed from cache.")
class RetrieverTool(Tool):
    """Agent tool that looks up documentation chunks for a text query.

    The lookup is delegated to a ``BM25Retriever`` built over the documents
    passed to the constructor; the seven best matches are returned as one
    formatted string.
    """

    name = "retriever"
    description = "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Build the BM25 index over ``docs``; ``kwargs`` go to ``Tool.__init__``."""
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=7)

    def forward(self, query: str) -> str:
        """Return the retrieved chunks for ``query`` as a single numbered listing."""
        assert isinstance(query, str), "Your search query must be a string"
        hits = self.retriever.invoke(query)
        sections = []
        for position, hit in enumerate(hits):
            # One banner line per hit, followed by the chunk text.
            sections.append(f"\n\n===== Document {position} =====\n{hit.page_content!s}")
        return "\nRetrieved documents:\n" + "".join(sections)
# Fetch the quantized Gemma 3 GGUF weights into ./models.
# The token is read from the environment here but is not passed to the
# download call below.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
os.makedirs("models", exist_ok=True)
logging.info("start download")
hf_hub_download(
    filename="google_gemma-3-4b-it-Q4_K_M.gguf",
    repo_id="bartowski/google_gemma-3-4b-it-GGUF",
    local_dir="./models",
)

# Retrieval tool over the documentation chunks prepared earlier in this file.
retriever_tool = RetrieverTool(docs_processed)
# Gemma 3 chat-template pieces, keyed by llama_cpp_agent role.
gemma_3_prompt_markers = {
    # There is no dedicated system turn: the system prompt is folded into the
    # first user message (see include_sys_prompt_in_first_user_message below).
    Roles.system: PromptMarkers("", "\n"),
    Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
    Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
    # Tool messages are emitted without any surrounding markers.
    Roles.tool: PromptMarkers("", ""),
}

# Formatter that renders a chat history into a Gemma 3 prompt string.
gemma_3_formatter = MessagesFormatter(
    prompt_markers=gemma_3_prompt_markers,
    pre_prompt="",  # nothing is emitted before the first turn
    bos_token="<bos>",
    eos_token="<eos>",
    include_sys_prompt_in_first_user_message=True,
    default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
    strip_prompt=False,  # keep the prompt's whitespace exactly as built
)
# Based on https://github.com/huggingface/smolagents/pull/450,
# largely rewritten following https://huggingface.co/spaces/sitammeur/Gemma-llamacpp
class LlamaCppModel(Model):
    """smolagents ``Model`` backed by a local llama.cpp (GGUF) model.

    Generation goes through ``llama_cpp_agent`` (``LlamaCppAgent`` on top of
    ``LlamaCppPythonProvider``) and uses the module-level ``gemma_3_formatter``
    to render the prompt.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        repo_id: Optional[str] = None,
        filename: Optional[str] = None,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        max_tokens: int = 1024,
        verbose: bool = False,
        **kwargs,
    ):
        """
        Initializes the LlamaCppModel.
        Parameters:
            model_path (str, optional): Path to the local model file.
            repo_id (str, optional): Hugging Face repository ID if loading from Hugging Face.
            filename (str, optional): Specific filename to load from the repository.
            n_gpu_layers (int, default=0): Number of GPU layers to use.
            n_ctx (int, default=8192): Context size for the model.
            max_tokens (int, default=1024): Stored on ``self.max_tokens``.
            verbose (bool, default=False): Passed to llama.cpp as its verbose flag.
            **kwargs: Additional keyword arguments, forwarded to ``Model.__init__``
                (and, in the repo_id+filename case, also to ``Llama.from_pretrained``).
        Raises:
            ValueError: If neither model_path nor repo_id+filename are provided.
        """
        super().__init__(**kwargs)
        self.flatten_messages_as_text = True
        self.max_tokens = max_tokens
        if model_path:
            # Honor the caller's n_gpu_layers / verbose instead of hard-coding
            # 0 / False; the defaults reproduce the previous behaviour.
            self.llm = Llama(
                model_path=model_path,
                flash_attn=False,
                n_gpu_layers=n_gpu_layers,
                n_ctx=n_ctx,
                n_threads=2,
                n_threads_batch=2,
                verbose=verbose,
            )
        elif repo_id and filename:
            self.llm = Llama.from_pretrained(
                repo_id=repo_id,
                filename=filename,
                n_gpu_layers=n_gpu_layers,
                n_ctx=n_ctx,
                max_tokens=max_tokens,
                verbose=verbose,
                **kwargs
            )
        else:
            raise ValueError("Must provide either model_path or repo_id+filename")

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs,
    ) -> ChatMessage:
        """
        Generates one assistant reply for the given conversation.
        Parameters:
            messages: Conversation so far; it is normalised through
                ``_prepare_completion_kwargs``. The first prepared message is used
                as the system prompt and the last one as the message to answer.
            stop_sequences: Forwarded to ``_prepare_completion_kwargs`` only; it is
                not applied to the sampling settings used for generation.
            grammar: Forwarded to ``_prepare_completion_kwargs`` only.
            tools_to_call_from: When not None, the reply is additionally passed
                through the parent's tool-argument parsing hook, if it has one.
            **kwargs: Forwarded to ``_prepare_completion_kwargs``.
        Returns:
            ChatMessage with the assistant role. Any exception is logged and
            turned into a reply whose content starts with "Error: ".
        """
        try:
            completion_kwargs = self._prepare_completion_kwargs(
                messages=messages,
                stop_sequences=stop_sequences,
                grammar=grammar,
                tools_to_call_from=tools_to_call_from,
                **kwargs
            )
            provider = LlamaCppPythonProvider(self.llm)

            # Only the prepared message list is consumed below.
            prepared_messages = completion_kwargs["messages"]
            system_message = prepared_messages[0]["content"]
            # The last message is the one to answer; it is removed so that it
            # does not also appear in the chat history.
            message = prepared_messages.pop()["content"]

            # Create the agent
            agent = LlamaCppAgent(
                provider,
                system_prompt=f"{system_message}",
                custom_messages_formatter=gemma_3_formatter,
                debug_output=True,
            )

            settings = provider.get_provider_default_settings()
            settings.temperature = 0.5
            settings.top_k = 40
            settings.top_p = 0.95
            # NOTE(review): the generation cap has always been this fixed value;
            # self.max_tokens and a per-call max_tokens are not applied. Kept
            # as-is to avoid changing output length -- confirm the intended source.
            settings.max_tokens = 2048
            settings.repeat_penalty = 1.1
            settings.stream = False

            # Remaining prepared messages are added to the history unchanged.
            # NOTE(review): this includes the first (system) message, which is
            # also passed as system_prompt above -- confirm this is intended.
            chat_history = BasicChatHistory()
            for prior_message in prepared_messages:
                chat_history.add_message(prior_message)

            content = agent.get_chat_response(
                message,
                llm_sampling_settings=settings,
                chat_history=chat_history,
                returns_streaming_generator=False,
                print_output=False,
            )
            response = ChatMessage(role=MessageRole.ASSISTANT, content=content)
            if tools_to_call_from is not None:
                # The original referenced the builtin `super` type without
                # calling it, which always raised AttributeError. Resolve the
                # hook on the parent class and skip it when the installed
                # smolagents version does not provide it as a method.
                parse_tool_args = getattr(super(), "parse_tool_args_if_needed", None)
                if callable(parse_tool_args):
                    return parse_tool_args(response)
            return response
        except Exception as e:
            # Best-effort boundary: report the failure as the model's reply so
            # the calling agent keeps running.
            logging.error(f"Model error: {e}")
            return ChatMessage(role="assistant", content=f"Error: {str(e)}")
# Instantiate the local Gemma 3 model from the GGUF file in ./models.
model = LlamaCppModel(
    model_path="models/google_gemma-3-4b-it-Q4_K_M.gguf",
    n_ctx=8192,
    verbose=False,
)

import yaml

# Read the agent's prompt templates. The encoding is given explicitly so the
# file is decoded the same way regardless of the platform's locale default.
with open("retriever.yaml", "r", encoding="utf-8") as f:
    prompt = f.read()

# Text passed to the agent as its description.
description="""
*CPU Rag Example with LlamaCpp*
Take a few minute.customized prompt is the key.
Reference
- [Qwen2.5-0.5B-Rag-Thinking](https://huggingface.co/spaces/Akjava/Qwen2.5-0.5B-Rag-Thinking-Flan-T5)
- [smolagents pull-450](https://github.com/huggingface/smolagents/pull/450)
- [Gemma-llamacpp](https://huggingface.co/spaces/sitammeur/Gemma-llamacpp)
- [Dataset(m-ric/huggingface_doc)](https://huggingface.co/datasets/m-ric/huggingface_doc)
"""

# Single-step code agent whose only tool is the documentation retriever.
agent = CodeAgent(
    prompt_templates=yaml.safe_load(prompt),
    model=model,
    tools=[retriever_tool],
    max_steps=1,
    verbosity_level=0,
    name="AGENT",
    description=description,
)
demo = GradioUI(agent)

if __name__ == "__main__":
    demo.launch()