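# app.py: CPU-only RAG demo for a Hugging Face Space. A smolagents CodeAgent
# answers questions about the Hugging Face documentation using a BM25 retriever
# tool and a Gemma-3-1b GGUF model served through llama-cpp-python, exposed via
# a Gradio UI.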
# Importing required libraries
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
import warnings
warnings.filterwarnings("ignore")
import datasets
import os
import json
import subprocess
import sys
import joblib
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple, Dict, Optional
from logger import logging
from exception import CustomExceptionHandling
from smolagents.gradio_ui import GradioUI
from smolagents import (
CodeAgent,
GoogleSearchTool,
Model,
Tool,
LiteLLMModel,
ToolCallingAgent,
    ChatMessage, tool, MessageRole,
)
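
# Build the chunked documentation corpus used by the retriever tool, caching
# the result with joblib so later restarts skip the dataset download and split.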
cache_file = "docs_processed.joblib"
if os.path.exists(cache_file):
docs_processed = joblib.load(cache_file)
print("Loaded docs_processed from cache.")
else:
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
source_docs = [
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=400,
chunk_overlap=20,
add_start_index=True,
strip_whitespace=True,
separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)
joblib.dump(docs_processed, cache_file)
print("Created and saved docs_processed to cache.")
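
# A smolagents Tool that wraps a BM25 retriever over the chunked docs and
# returns the top-k matching chunks as plain text for the agent to read.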
class RetrieverTool(Tool):
name = "retriever"
description = "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
inputs = {
"query": {
"type": "string",
"description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
}
}
output_type = "string"
def __init__(self, docs, **kwargs):
super().__init__(**kwargs)
self.retriever = BM25Retriever.from_documents(
docs,
k=7,
)
def forward(self, query: str) -> str:
assert isinstance(query, str), "Your search query must be a string"
docs = self.retriever.invoke(
query,
)
return "\nRetrieved documents:\n" + "".join(
[
f"\n\n===== Document {str(i)} =====\n" + str(doc.page_content)
for i, doc in enumerate(docs)
]
)
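
# Quick sanity check (hypothetical usage, not part of the app flow):
#   print(RetrieverTool(docs_processed).forward("how to push a model to the Hub"))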
# Download the GGUF model file
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
os.makedirs("models",exist_ok=True)
logging.info("start download")
hf_hub_download(
repo_id="bartowski/google_gemma-3-1b-it-GGUF",
filename="google_gemma-3-1b-it-Q5_K_M.gguf",
local_dir="./models",
)
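
# The GGUF file lands in ./models/google_gemma-3-1b-it-Q5_K_M.gguf and is
# loaded from that local path by LlamaCppModel below.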
retriever_tool = RetrieverTool(docs_processed)
# Based on https://github.com/huggingface/smolagents/pull/450,
# largely adapted from https://huggingface.co/spaces/sitammeur/Gemma-llamacpp
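# LlamaCppModel adapts llama-cpp-python to the smolagents Model interface:
# __init__ loads a GGUF model (from a local path or a Hub repo), and __call__
# turns the prepared chat messages into a llama-cpp-agent conversation and
# returns the reply as a smolagents ChatMessage.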
class LlamaCppModel(Model):
def __init__(
self,
model_path: Optional[str] = None,
repo_id: Optional[str] = None,
filename: Optional[str] = None,
n_gpu_layers: int = 0,
n_ctx: int = 8192,
max_tokens: int = 1024,
verbose:bool = False,
**kwargs,
):
"""
        Initializes the LlamaCppModel.

        Parameters:
            model_path (str, optional): Path to a local GGUF model file.
            repo_id (str, optional): Hugging Face repository ID to load from.
            filename (str, optional): Specific GGUF filename to load from the repository.
            n_gpu_layers (int, default=0): Number of layers to offload to the GPU.
            n_ctx (int, default=8192): Context window size for the model.
            max_tokens (int, default=1024): Maximum number of tokens to generate.
            verbose (bool, default=False): Whether llama.cpp logs verbosely.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If neither model_path nor repo_id+filename is provided.
"""
        from llama_cpp import Llama

        super().__init__(**kwargs)
        self.flatten_messages_as_text = True
        self.max_tokens = max_tokens
        if model_path:
            # Local GGUF file; small batch/thread counts suit a CPU-only Space.
            self.llm = Llama(
                model_path=model_path,
                flash_attn=False,
                n_gpu_layers=n_gpu_layers,
                n_batch=8,
                n_ctx=n_ctx,
                n_threads=2,
                n_threads_batch=2,
                verbose=verbose,
            )
elif repo_id and filename:
self.llm = Llama.from_pretrained(
repo_id=repo_id,
filename=filename,
n_gpu_layers=n_gpu_layers,
n_ctx=n_ctx,
max_tokens=max_tokens,
verbose=verbose,
**kwargs
)
else:
raise ValueError("Must provide either model_path or repo_id+filename")
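
    # __call__ bridges smolagents' completion interface to llama-cpp-agent: it
    # pulls out the system prompt and the latest user message, rebuilds the chat
    # history, applies the Gemma-2 message format, and generates the reply.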
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
**kwargs,
) -> ChatMessage:
from llama_cpp import LlamaGrammar
try:
completion_kwargs = self._prepare_completion_kwargs(
messages=messages,
stop_sequences=stop_sequences,
grammar=grammar,
tools_to_call_from=tools_to_call_from,
**kwargs
)
if not tools_to_call_from:
completion_kwargs.pop("tools", None)
completion_kwargs.pop("tool_choice", None)
filtered_kwargs = {
k: v for k, v in completion_kwargs.items()
if k not in ["messages", "stop", "grammar", "max_tokens", "tools_to_call_from"]
}
max_tokens = (
kwargs.get("max_tokens")
or self.max_tokens
or 1024
)
provider = LlamaCppPythonProvider(self.llm)
            system_message = completion_kwargs["messages"][0]["content"]
            message = completion_kwargs["messages"].pop()["content"]
# Create the agent
agent = LlamaCppAgent(
provider,
system_prompt=f"{system_message}",
predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
debug_output=True,
)
            # Sampling defaults; max_tokens was resolved above from kwargs or the instance default.
            temperature = 0.7
            top_k = 40
            top_p = 0.95
            repeat_penalty = 1.1
settings = provider.get_provider_default_settings()
settings.temperature = temperature
settings.top_k = top_k
settings.top_p = top_p
settings.max_tokens = max_tokens
settings.repeat_penalty = repeat_penalty
settings.stream = False
            messages = BasicChatHistory()
            # Rebuild the conversation for llama-cpp-agent from the prepared messages.
            for from_message in completion_kwargs["messages"]:
                if from_message["role"] == MessageRole.USER:
                    history_message = {"role": MessageRole.USER, "content": from_message["content"]}
                elif from_message["role"] == MessageRole.SYSTEM:
                    history_message = {"role": MessageRole.SYSTEM, "content": from_message["content"]}
                else:
                    history_message = {"role": MessageRole.ASSISTANT, "content": from_message["content"]}
                messages.add_message(history_message)
stream = agent.get_chat_response(
message,
llm_sampling_settings=settings,
chat_history=messages,
returns_streaming_generator=False,
print_output=True,
)
            content = stream
            message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
            if tools_to_call_from is not None:
                return super().parse_tool_args_if_needed(message)
            return message
except Exception as e:
logging.error(f"Model error: {e}")
return ChatMessage(role="assistant", content=f"Error: {str(e)}")
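
# Instantiate the model from the locally downloaded GGUF file; verbose=False
# keeps llama.cpp output out of the Space logs.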
model = LlamaCppModel(
    model_path="models/google_gemma-3-1b-it-Q5_K_M.gguf",
    n_ctx=8192,
    verbose=False,
)
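
# test.yaml is expected to hold the CodeAgent prompt templates; it is parsed
# with yaml.safe_load and passed to the agent below.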
import yaml
with open("test.yaml", "r") as f:
prompt = f.read()
description = """
*CPU RAG Example with LlamaCpp*

This may take a few minutes.

References
- [pull-450](https://github.com/huggingface/smolagents/pull/450)
- [Gemma-llamacpp](https://huggingface.co/spaces/sitammeur/Gemma-llamacpp)
"""
# A tool-calling agent is not supported here, so a CodeAgent is used instead.
agent = CodeAgent(
    prompt_templates=yaml.safe_load(prompt),
    model=model,
    tools=[retriever_tool],
    max_steps=2,
    verbosity_level=2,
    name="AGENT",
    description=description,
)
demo = GradioUI(agent)
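
# GradioUI (from smolagents) wraps the agent in a chat interface; launch()
# starts the Gradio server when the script is run directly.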
if __name__ == "__main__":
demo.launch()