|
# Standard library
import json
import os
import re
import subprocess
import sys
import warnings
from typing import Dict, List, Optional, Tuple

# Silence library warnings before the third-party imports below are executed.
warnings.filterwarnings("ignore")

# Third-party
import datasets
import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.providers import LlamaCppPythonProvider
from smolagents import (
    ChatMessage,
    CodeAgent,
    GoogleSearchTool,
    LiteLLMModel,
    MessageRole,
    Model,
    Tool,
    ToolCallingAgent,
    tool,
)
from smolagents.gradio_ui import GradioUI

# Project-local
from exception import CustomExceptionHandling
from logger import logging
|
|
|
# Chunked documentation corpus consumed by the retriever. The dataset is only
# downloaded and split when no cache file exists; otherwise the previously
# dumped chunks are loaded from disk.
cache_file = "docs_processed.joblib"

if not os.path.exists(cache_file):
    # Cache miss: build the chunks from the Hugging Face documentation dataset.
    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")

    source_docs = [
        Document(
            page_content=record["text"],
            # Keep only the second "/"-separated component of the source path.
            metadata={"source": record["source"].split("/")[1]},
        )
        for record in knowledge_base
    ]

    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", " ", ""],
        chunk_size=1000,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
    )
    docs_processed = text_splitter.split_documents(source_docs)

    joblib.dump(docs_processed, cache_file)
    print("Created and saved docs_processed to cache.")
else:
    # NOTE(review): joblib.load unpickles the file, so the cache must only ever
    # be a file this script wrote itself.
    docs_processed = joblib.load(cache_file)
    print("Loaded docs_processed from cache.")
|
|
|
class RetrieverTool(Tool):
    """Tool that returns the documentation chunks matching a text query.

    Retrieval is delegated to a ``BM25Retriever`` built from the documents
    passed to the constructor.
    """

    # NOTE(review): the description advertises "semantic search" while the
    # retriever below is BM25 -- confirm the wording is intended.
    name = "retriever"
    description = "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Index ``docs`` with BM25; extra keyword arguments go to ``Tool``."""
        super().__init__(**kwargs)
        # k=7 is handed to the retriever as its result-count setting.
        self.retriever = BM25Retriever.from_documents(docs, k=7)

    def forward(self, query: str) -> str:
        """Return the retrieved chunks for ``query`` as one formatted string.

        Each chunk is preceded by a ``===== Document <index> =====`` banner,
        with indices starting at 0.
        """
        assert isinstance(query, str), "Your search query must be a string"

        hits = self.retriever.invoke(query)

        sections = []
        for position, hit in enumerate(hits):
            banner = f"\n\n===== Document {position} =====\n"
            sections.append(banner + str(hit.page_content))
        return "\nRetrieved documents:\n" + "".join(sections)
|
|
|
# Build the retriever once at import time from the processed chunks.
retriever_tool = RetrieverTool(docs_processed)

# Read from the environment; not referenced by the download calls below.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

# Download both GGUF files offered in the model dropdown into ./models.
for _gguf_filename in (
    "google_gemma-3-1b-it-Q6_K.gguf",
    "google_gemma-3-1b-it-Q5_K_M.gguf",
):
    hf_hub_download(
        repo_id="bartowski/google_gemma-3-1b-it-GGUF",
        filename=_gguf_filename,
        local_dir="./models",
    )
|
|
|
|
|
# Heading and blurb passed to the chat interface.
title = "Gemma Llama.cpp"
description = (
    "Gemma 3 is a family of lightweight, multimodal open models that offers "
    "advanced capabilities like large context windows and multilingual support, "
    "enabling diverse applications on various devices."
)

# Currently loaded Llama instance and the model file name it was loaded from;
# both start empty and are filled in by respond() on first use.
llm = None
llm_model = None
|
|
|
|
|
# Few-shot examples (user question, expected search query) shown to the
# query-rewriting agent.
_QUERY_REWRITE_EXAMPLES = (
    ("What is transformer?", "transformer"),
    (
        "How does a transformer model work in natural language processing?",
        "transformer model natural language processing",
    ),
    (
        "What are the advantages of using transformers over recurrent neural networks?",
        "transformer vs recurrent neural network advantages",
    ),
    ("Explain the attention mechanism in transformers.", "transformer attention mechanism"),
    ("What are the different types of transformer architectures?", "transformer architectures"),
    ("What is the history of the transformer model?", "transformer model history"),
)

# System prompt for the query rewriter: instructions followed by the examples,
# each example preceded by one blank line.
query_system = (
    "\n"
    "You are a query rewriter. Your task is to convert a user's question into a concise search query suitable for information retrieval.\n"
    "The goal is to identify the most important keywords for a search engine.\n"
    "\n"
    "Here are some examples:\n"
) + "".join(
    f"\nUser Question: {asked}\nSearch Query: {keywords}\n"
    for asked, keywords in _QUERY_REWRITE_EXAMPLES
)
|
# Matches every character that is not an ASCII letter, digit, underscore,
# hyphen or space. Non-ASCII characters fall in this class as well, so one pass
# is enough.
_NON_QUERY_CHARS = re.compile(r"[^a-zA-Z0-9_\- ]")


def clean_text(text):
    """Remove every character except ASCII letters, digits, ``_``, ``-`` and spaces.

    Removed characters (punctuation, newlines, non-ASCII text, ...) are dropped,
    not replaced, so neighbouring characters are joined together.

    Args:
        text: The string to sanitise.

    Returns:
        The sanitised string; an empty input gives an empty string.

    >>> clean_text("transformer model!")
    'transformer model'
    """
    return _NON_QUERY_CHARS.sub("", text)
|
def to_query(provider, question):
    """Rewrite a user question into a keyword search query with the LLM.

    A fresh ``LlamaCppAgent`` is created with ``query_system`` as its system
    prompt, asked (non-streaming, empty chat history, provider default
    sampling settings) to rewrite ``question``, and the reply is passed
    through ``clean_text``.

    Args:
        provider: Provider object handed to ``LlamaCppAgent``; also supplies
            the default sampling settings.
        question: The user's question, interpolated into the prompt.

    Returns:
        The model's reply after ``clean_text``.

    Raises:
        CustomExceptionHandling: Wraps any exception raised while building the
            agent, generating the reply or cleaning it.
    """
    print(f"<question> = {question}")
    print(f"<query sytem> = {query_system}")
    try:
        rewriter = LlamaCppAgent(
            provider,
            system_prompt=query_system,
            predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
            debug_output=True,
        )

        prompt = (
            "\n"
            "Now, rewrite the following question:\n"
            "User Question: %s\n"
            "Search Query:\n"
        ) % question

        print("<message>")
        print(prompt)

        reply = rewriter.get_chat_response(
            prompt,
            llm_sampling_settings=provider.get_provider_default_settings(),
            chat_history=BasicChatHistory(),
            returns_streaming_generator=False,
            print_output=False,
        )
        return clean_text(reply)
    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e
|
|
|
# Prompt sent as the user turn: the retrieved documents followed by the
# original question. Filled with ``% (documents, question)``.
_RAG_PROMPT_TEMPLATE = """
You are an AI assistant that answers questions based on below retrievered documents.

Documents:
---
%s
---
Question: %s
Answer:
"""


def respond(
    message: str,
    history: List[Tuple[str, str]],
    model: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
):
    """
    Answer a chat message from retrieved documentation using Gemma via Llama.cpp.

    Steps: (re)load the selected model if needed, rewrite ``message`` into a
    search query with ``to_query``, fetch matching chunks with
    ``retriever_tool``, then stream the model's answer to a prompt containing
    those chunks and the original question.

    Args:
        message (str): The user's message.
        history (List[Tuple[str, str]]): Previous (user, assistant) turns; each
            pair is replayed into the chat history.
        model (str): File name of the GGUF model under ``models/``.
        system_message (str): Accepted to match the UI inputs but not used;
            the agent runs with a fixed system prompt.
        max_tokens (int): The maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        top_p (float): Top-p sampling value.
        top_k (int): Top-k sampling value.
        repeat_penalty (float): Repetition penalty.

    Yields:
        str: The response accumulated so far, once per streamed piece.

    Raises:
        CustomExceptionHandling: Wraps any exception raised during loading,
            retrieval or generation.
    """
    global llm
    global llm_model

    try:
        # Load the model on first use and whenever another file is selected.
        if llm is None or llm_model != model:
            # NOTE(review): n_ctx=2048 has to hold the retrieved documents, the
            # replayed history and the generated tokens -- confirm it suffices.
            llm = Llama(
                model_path=f"models/{model}",
                flash_attn=False,
                n_gpu_layers=0,
                n_batch=8,
                n_ctx=2048,
                n_threads=2,
                n_threads_batch=2,
            )
            llm_model = model
        provider = LlamaCppPythonProvider(llm)

        # Retrieval works on a keyword query, not on the raw question.
        query = to_query(provider, message)
        print("<query>")
        print(f"from {message} to {query}")
        text = retriever_tool(query=query)

        retriever_system = _RAG_PROMPT_TEMPLATE % (text, message)

        agent = LlamaCppAgent(
            provider,
            system_prompt="you are kind assistant",
            predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
            debug_output=True,
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        # Replay earlier turns so the model sees the conversation so far.
        messages = BasicChatHistory()
        for msn in history:
            user = {"role": Roles.user, "content": msn[0]}
            assistant = {"role": Roles.assistant, "content": msn[1]}
            messages.add_message(user)
            messages.add_message(assistant)

        stream = agent.get_chat_response(
            retriever_system,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        logging.info("Response stream generated successfully")

        # Yield the growing answer so the UI can update as pieces arrive.
        outputs = ""
        for output in stream:
            outputs += output
            yield outputs

    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e
|
|
|
|
|
|
|
# Accordion that groups the extra generation controls.
_parameters_accordion = gr.Accordion(
    label="⚙️ Parameters", open=False, render=False
)

# Extra inputs, in the order respond() receives them after (message, history):
# model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty.
_model_dropdown = gr.Dropdown(
    choices=[
        "google_gemma-3-1b-it-Q6_K.gguf",
        "google_gemma-3-1b-it-Q5_K_M.gguf",
    ],
    value="google_gemma-3-1b-it-Q5_K_M.gguf",
    label="Model",
    info="Select the AI model to use for chat",
)
_system_prompt_box = gr.Textbox(
    value="You are a helpful assistant.",
    label="System Prompt",
    info="Define the AI assistant's personality and behavior",
    lines=2,
    visible=False,
)
_max_tokens_slider = gr.Slider(
    minimum=512,
    maximum=2048,
    value=1024,
    step=1,
    label="Max Tokens",
    info="Maximum length of response (higher = longer replies)",
)
_temperature_slider = gr.Slider(
    minimum=0.1,
    maximum=2.0,
    value=0.7,
    step=0.1,
    label="Temperature",
    info="Creativity level (higher = more creative, lower = more focused)",
)
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p",
    info="Nucleus sampling threshold",
)
_top_k_slider = gr.Slider(
    minimum=1,
    maximum=100,
    value=40,
    step=1,
    label="Top-k",
    info="Limit vocabulary choices to top K tokens",
)
_repeat_penalty_slider = gr.Slider(
    minimum=1.0,
    maximum=2.0,
    value=1.1,
    step=0.1,
    label="Repetition Penalty",
    info="Penalize repeated words (higher = less repetition)",
)

_chat_window = gr.Chatbot(scale=1, show_copy_button=True)

# Chat UI wired to respond().
demo = gr.ChatInterface(
    respond,
    examples=[["What is the Diffuser?"], ["Tell me About Huggingface."], ["How to upload dataset?"]],
    additional_inputs_accordion=_parameters_accordion,
    additional_inputs=[
        _model_dropdown,
        _system_prompt_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
        _top_k_slider,
        _repeat_penalty_slider,
    ],
    theme="Ocean",
    submit_btn="Send",
    stop_btn="Stop",
    title=title,
    description=description,
    chatbot=_chat_window,
    flagging_mode="never",
)


# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=False)
|
|