from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
# Importing required libraries
import warnings
warnings.filterwarnings("ignore")
import datasets
import os
import json
import re
import subprocess
import sys
import joblib
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple, Dict, Optional
from logger import logging
from exception import CustomExceptionHandling
from smolagents.gradio_ui import GradioUI
from smolagents import (
CodeAgent,
GoogleSearchTool,
Model,
Tool,
LiteLLMModel,
ToolCallingAgent,
    ChatMessage, tool, MessageRole,
)
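# Build (or load from cache) the documentation chunks used for retrieval.
# The m-ric/huggingface_doc dataset is split into 1000-character chunks with a
# 50-character overlap, and the result is cached with joblib so later startups
# skip the download and splitting step.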
cache_file = "docs_processed.joblib"
if os.path.exists(cache_file):
docs_processed = joblib.load(cache_file)
print("Loaded docs_processed from cache.")
else:
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
source_docs = [
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=50,
add_start_index=True,
strip_whitespace=True,
separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)
joblib.dump(docs_processed, cache_file)
print("Created and saved docs_processed to cache.")
class RetrieverTool(Tool):
name = "retriever"
description = "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
inputs = {
"query": {
"type": "string",
"description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
}
}
output_type = "string"
def __init__(self, docs, **kwargs):
super().__init__(**kwargs)
self.retriever = BM25Retriever.from_documents(
docs,
k=7,
)
def forward(self, query: str) -> str:
assert isinstance(query, str), "Your search query must be a string"
docs = self.retriever.invoke(
query,
)
return "\nRetrieved documents:\n" + "".join(
[
f"\n\n===== Document {str(i)} =====\n" + str(doc.page_content)
for i, doc in enumerate(docs)
]
)
retriever_tool = RetrieverTool(docs_processed)
# Download the GGUF model files (two quantizations of Gemma 3 1B) into ./models.
# A Hugging Face token is only required for gated or private repos; it is passed
# through when the HUGGINGFACE_TOKEN environment variable is set.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
hf_hub_download(
    repo_id="bartowski/google_gemma-3-1b-it-GGUF",
    filename="google_gemma-3-1b-it-Q6_K.gguf",
    local_dir="./models",
    token=huggingface_token,
)
hf_hub_download(
    repo_id="bartowski/google_gemma-3-1b-it-GGUF",
    filename="google_gemma-3-1b-it-Q5_K_M.gguf",
    local_dir="./models",
    token=huggingface_token,
)
# Set the title and description
title = "Gemma Llama.cpp"
description = """Gemma 3 is a family of lightweight, multimodal open models that offers advanced capabilities like large context windows and multilingual support, enabling diverse applications on various devices."""
llm = None
llm_model = None
query_system = """
You are a query rewriter. Your task is to convert a user's question into a concise search query suitable for information retrieval.
The goal is to identify the most important keywords for a search engine.
Here are some examples:
User Question: What is transformer?
Search Query: transformer
User Question: How does a transformer model work in natural language processing?
Search Query: transformer model natural language processing
User Question: What are the advantages of using transformers over recurrent neural networks?
Search Query: transformer vs recurrent neural network advantages
User Question: Explain the attention mechanism in transformers.
Search Query: transformer attention mechanism
User Question: What are the different types of transformer architectures?
Search Query: transformer architectures
User Question: What is the history of the transformer model?
Search Query: transformer model history
"""
def clean_text(text):
    cleaned = re.sub(r"[^\x00-\x7F]+", "", text)  # Drop non-ASCII characters
    cleaned = re.sub(r"[^a-zA-Z0-9_\- ]", "", cleaned)  # Keep only letters, digits, underscores, hyphens, and spaces
    return cleaned
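# Rewrite the user's question into a short keyword query using the few-shot
# prompt above, then sanitize the result before it is used for retrieval.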
def to_query(provider,question):
print(f"<question> = {question}")
print(f"<query sytem> = {query_system}")
try:
        query_agent = LlamaCppAgent(
            provider,
            system_prompt=query_system,
            predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
            debug_output=True,
        )
message="""
Now, rewrite the following question:
User Question: %s
Search Query:
"""%question
print("<message>")
print(message)
settings = provider.get_provider_default_settings()
messages = BasicChatHistory()
result = query_agent.get_chat_response(
#query_system+message,
message,
llm_sampling_settings=settings,
chat_history=messages,
returns_streaming_generator=False,
print_output=False,
)
return clean_text(result)
except Exception as e:
# Custom exception handling
raise CustomExceptionHandling(e, sys) from e
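# Main chat handler: rewrite the question into a search query, retrieve the
# most relevant documentation chunks with BM25, embed them in the prompt, and
# stream the model's answer back to the Gradio UI.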
def respond(
message: str,
history: List[Tuple[str, str]],
model: str,
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
repeat_penalty: float,
):
"""
Respond to a message using the Gemma3 model via Llama.cpp.
Args:
- message (str): The message to respond to.
- history (List[Tuple[str, str]]): The chat history.
- model (str): The model to use.
- system_message (str): The system message to use.
- max_tokens (int): The maximum number of tokens to generate.
- temperature (float): The temperature of the model.
- top_p (float): The top-p of the model.
- top_k (int): The top-k of the model.
- repeat_penalty (float): The repetition penalty of the model.
Returns:
str: The response to the message.
"""
try:
# Load the global variables
global llm
global llm_model
# Load the model
if llm is None or llm_model != model:
llm = Llama(
model_path=f"models/{model}",
flash_attn=False,
n_gpu_layers=0,
n_batch=8,
n_ctx=2048,
n_threads=2,
n_threads_batch=2,
)
llm_model = model
provider = LlamaCppPythonProvider(llm)
query = to_query(provider,message)
print("<query>")
print(f"from {message} to {query}")
text = retriever_tool(query=f"{query}")
retriever_system="""
You are an AI assistant that answers questions based on documents provided by the user. Wait for the user to send a document. Once you receive the document, carefully read its contents and then answer the following question:
Question: %s
Document:
""" % message
retriever_system="""
You are an AI assistant that answers questions based on below retrievered documents.
Documents:
---
%s
---
Question: %s
Answer:
""" % (text,message)
#[Wait for user's document]
# Create the agent
        agent = LlamaCppAgent(
            provider,
            system_prompt="You are a kind assistant.",
            predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
            debug_output=True,
        )
# Set the settings like temperature, top-k, top-p, max tokens, etc.
settings = provider.get_provider_default_settings()
settings.temperature = temperature
settings.top_k = top_k
settings.top_p = top_p
settings.max_tokens = max_tokens
settings.repeat_penalty = repeat_penalty
settings.stream = True
messages = BasicChatHistory()
# Add the chat history
for msn in history:
user = {"role": Roles.user, "content": msn[0]}
assistant = {"role": Roles.assistant, "content": msn[1]}
messages.add_message(user)
messages.add_message(assistant)
# Get the response stream
        stream = agent.get_chat_response(
            retriever_system,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )
# Log the success
logging.info("Response stream generated successfully")
# Generate the response
outputs = ""
for output in stream:
outputs += output
yield outputs
# Handle exceptions that may occur during the process
except Exception as e:
# Custom exception handling
raise CustomExceptionHandling(e, sys) from e
# Create a chat interface
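# The additional_inputs below are passed to respond() positionally, so their
# order must match the parameters after message and history:
# model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty.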
demo = gr.ChatInterface(
respond,
examples=[["What is the Diffuser?"], ["Tell me About Huggingface."], ["How to upload dataset?"]],
additional_inputs_accordion=gr.Accordion(
label="⚙️ Parameters", open=False, render=False
),
additional_inputs=[
gr.Dropdown(
choices=[
"google_gemma-3-1b-it-Q6_K.gguf",
"google_gemma-3-1b-it-Q5_K_M.gguf",
],
value="google_gemma-3-1b-it-Q5_K_M.gguf",
label="Model",
info="Select the AI model to use for chat",
),
gr.Textbox(
value="You are a helpful assistant.",
label="System Prompt",
info="Define the AI assistant's personality and behavior",
lines=2,visible=False
),
gr.Slider(
minimum=512,
maximum=2048,
value=1024,
step=1,
label="Max Tokens",
info="Maximum length of response (higher = longer replies)",
),
gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature",
info="Creativity level (higher = more creative, lower = more focused)",
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p",
info="Nucleus sampling threshold",
),
gr.Slider(
minimum=1,
maximum=100,
value=40,
step=1,
label="Top-k",
info="Limit vocabulary choices to top K tokens",
),
gr.Slider(
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.1,
label="Repetition Penalty",
info="Penalize repeated words (higher = less repetition)",
),
],
theme="Ocean",
submit_btn="Send",
stop_btn="Stop",
title=title,
description=description,
chatbot=gr.Chatbot(scale=1, show_copy_button=True),
flagging_mode="never",
)
# Launch the chat interface
if __name__ == "__main__":
demo.launch(debug=False)