# Author: Akjava
# Commit: "Update app.py" (b3716da, verified)
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
import re
# Importing required libraries
import warnings
warnings.filterwarnings("ignore")
import datasets
import os
import json
import subprocess
import sys
import joblib
from llama_cpp import Llama
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple,Dict,Optional
from logger import logging
from exception import CustomExceptionHandling
# Splitting the whole documentation dataset is slow, so the resulting chunks
# are cached on disk with joblib and reused on later starts.
cache_file = "docs_processed.joblib"
if not os.path.exists(cache_file):
    # Cache miss: download the Hugging Face documentation dataset and chunk it.
    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
    # The "source" metadata keeps only the second "/"-separated path component.
    source_docs = [
        Document(
            page_content=row["text"],
            metadata={"source": row["source"].split("/")[1]},
        )
        for row in knowledge_base
    ]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)
    joblib.dump(docs_processed, cache_file)
    print("Created and saved docs_processed to cache.")
else:
    # NOTE(review): joblib.load unpickles the file; only load a cache this app
    # wrote itself, never one supplied by an untrusted party.
    docs_processed = joblib.load(cache_file)
class RetrieverTool:
    """Keyword retriever over the chunked documentation, backed by BM25.

    The ``name``/``description``/``inputs``/``output_type`` class attributes
    describe the tool in an agent-tool style schema.
    NOTE(review): ``description`` says "semantic search", but retrieval is
    lexical BM25 ranking.
    """

    name = "retriever"
    description = "Uses semantic search to retrieve the parts of documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, k: int = 7, **kwargs):
        """Build a BM25 index over *docs*.

        Args:
            docs: LangChain ``Document`` objects to index.
            k: Number of documents returned per query (defaults to 7).
            **kwargs: Accepted for compatibility; currently ignored.
        """
        self.retriever = BM25Retriever.from_documents(docs, k=k)

    def __call__(self, query: str) -> str:
        """Retrieve documents for *query* and format them as one text block.

        Args:
            query: Search text.

        Returns:
            ``"\\nRetrieved documents:\\n"`` followed by one
            ``"===== Document i ====="`` section per retrieved document.

        Raises:
            TypeError: If *query* is not a string.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        docs = self.retriever.invoke(query)
        sections = [
            f"\n\n===== Document {i} =====\n{doc.page_content}"
            for i, doc in enumerate(docs)
        ]
        return "\nRetrieved documents:\n" + "".join(sections)
# BM25 retriever over the cached documentation chunks.
retriever_tool = RetrieverTool(docs_processed)

# Download the GGUF model files into ./models.
# The optional HUGGINGFACE_TOKEN is forwarded to the Hub; when it is unset the
# value is None and hf_hub_download falls back to its default authentication.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
hf_hub_download(
    repo_id="mradermacher/Qwen2.5-0.5B-Rag-Thinking-i1-GGUF",
    filename="Qwen2.5-0.5B-Rag-Thinking.i1-Q6_K.gguf",
    local_dir="./models",
    token=huggingface_token,
)
# Size variant of the Flan-T5 query-rewriting model; to_query() builds the
# model file name from this value.
t5_size = "base"
hf_hub_download(
    repo_id=f"Felladrin/gguf-flan-t5-{t5_size}",
    filename=f"flan-t5-{t5_size}.Q8_0.gguf",
    local_dir="./models",
    token=huggingface_token,
)
# Few-shot instructions for turning a user question into a short keyword
# search query (question -> "Search Query" examples).
# NOTE(review): to_query() sends these same instructions and examples to the
# Flan-T5 rewriter; keep the two in sync if either is edited.
query_system = """
You are a query rewriter. Your task is to convert a user's question into a concise search query suitable for information retrieval.
The goal is to identify the most important keywords for a search engine.
Here are some examples:
User Question: What is transformer?
Search Query: transformer
User Question: How does a transformer model work in natural language processing?
Search Query: transformer model natural language processing
User Question: What are the advantages of using transformers over recurrent neural networks?
Search Query: transformer vs recurrent neural network advantages
User Question: Explain the attention mechanism in transformers.
Search Query: transformer attention mechanism
User Question: What are the different types of transformer architectures?
Search Query: transformer architectures
User Question: What is the history of the transformer model?
Search Query: transformer model history
"""
# Anything outside this set is dropped from a search query: ASCII letters,
# digits, underscore, hyphen and the plain space character.
_QUERY_DISALLOWED_CHARS = re.compile(r"[^a-zA-Z0-9_\- ]")


def clean_text(text):
    """Reduce *text* to characters that are safe in a keyword search query.

    Every character other than ASCII letters, digits, ``_``, ``-`` and space
    is deleted, which includes all non-ASCII characters and punctuation such
    as ``*`` or ``/``. Occurrences of ``"---"`` are then removed.
    """
    kept = _QUERY_DISALLOWED_CHARS.sub("", text)
    return kept.replace("---", "")
def generate_t5(
    llama,
    message,
    *,
    temperature: float = 0.5,
    top_k: int = 40,
    top_p: float = 0.95,
    repeat_penalty: float = 1.2,
    max_tokens: Optional[int] = None,
):
    """Run one encoder-decoder (T5) generation and return the generated text.

    The tokenized *message* is fed to the encoder, then the decoder samples
    tokens starting from the model's decoder-start token until it emits the
    end-of-sequence token. *message* must tokenize to fewer tokens than the
    model context (noted as 512 by default).

    Args:
        llama: A loaded llama_cpp ``Llama`` wrapping a T5 GGUF model.
        message: Prompt text for the encoder.
        temperature: Sampling temperature (passed as ``temp``).
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling threshold.
        repeat_penalty: Repetition penalty.
        max_tokens: Optional cap on generated tokens; ``None`` or a
            non-positive value means "generate until end-of-sequence".

    Returns:
        The generated text, excluding the end-of-sequence token.

    Raises:
        ValueError: If *llama* is ``None``.
        CustomExceptionHandling: Wrapping any error raised during generation.
    """
    if llama is None:
        raise ValueError("llama not initialized")
    # Non-positive limits mean "unlimited", like llama-cpp-python's max_tokens.
    limit = max_tokens if max_tokens is not None and max_tokens > 0 else None
    try:
        llama.encode(llama.tokenize(f"{message}".encode("utf-8")))
        eos = llama.token_eos()
        pieces = []
        for token in llama.generate(
            [llama.decoder_start_token()],
            top_k=top_k,
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty,
        ):
            # Stop before rendering EOS so it never leaks into the output.
            if token == eos:
                break
            pieces.append(llama.detokenize([token]))
            if limit is not None and len(pieces) >= limit:
                break
        # Decode once at the end: a multi-byte UTF-8 character can be split
        # across tokens, so decoding token by token could raise.
        return b"".join(pieces).decode("utf-8", errors="replace")
    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e
# Flan-T5 model used by to_query(); created lazily on the first call.
llama = None


def to_query(question):
    """Rewrite a user question into a short keyword query for BM25 retrieval.

    The few-shot instructions in ``query_system`` are followed by the
    question and sent to the Flan-T5 model. The model is loaded from
    ``models/flan-t5-{t5_size}.Q8_0.gguf`` on first use and kept in the
    module-global ``llama``. The output is passed through ``clean_text``. If
    nothing usable remains, the cleaned question itself is used as the query.

    Args:
        question: The user's natural-language question.

    Returns:
        The cleaned keyword query, with surrounding whitespace stripped.

    Raises:
        CustomExceptionHandling: If loading the model or generation fails.
    """
    global llama
    prompt = (
        query_system
        + "---\nNow, rewrite the following question:\n"
        + f"User Question: {question}\nSearch Query:\n"
    )
    try:
        if llama is None:
            model_id = f"flan-t5-{t5_size}.Q8_0.gguf"
            llama = Llama(
                f"models/{model_id}",
                flash_attn=False,
                verbose=False,
                n_gpu_layers=0,
                n_threads=2,
                n_threads_batch=2,
            )
        query = clean_text(generate_t5(llama, prompt)).strip()
        if not query:
            # The rewriter produced nothing usable; search with the question itself.
            query = clean_text(f"{question}").strip()
    except CustomExceptionHandling:
        # generate_t5 already wraps its own failures; don't wrap them twice.
        raise
    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e
    return query
# ChatML-style prompt template for the Qwen answer model. It is the default
# value of the "System Prompt" textbox. respond() fills its two %s placeholders
# with (retrieved documents, user question) using the % operator.
# The assistant turn is left open right after "<think>", presumably so that
# generation starts inside the reasoning section requested by the system text.
qwen_prompt = """<|im_start|>system
You answer questions from the user, always using the context provided as a basis.
Write down your reasoning for answering the question, between the <think> and </think> tags.<|im_end|>
<|im_start|>user
Context:
%s
Question:
%s<|im_end|>
<|im_start|>assistant
<think>"""
# Answer model currently loaded by answer(), and the GGUF file it came from.
llm = None
llm_model = None


def answer(document: str, question: str, model: str = "Qwen2.5-0.5B-Rag-Thinking.i1-Q6_K.gguf") -> "Llama":
    """Make sure the Qwen answer model *model* is loaded into the global ``llm``.

    Despite its name this function does not generate text: respond() streams
    the completion from the global ``llm`` after calling it. Loading a GGUF
    model with an 8k context is expensive, so the instance is cached. It is
    reloaded only when a different *model* file is requested.

    Args:
        document: Retrieved context. Currently unused; kept for compatibility.
        question: User question. Currently unused; kept for compatibility.
        model: GGUF file name under ``models/``.

    Returns:
        The loaded ``Llama`` instance, which is also stored in the global ``llm``.
    """
    global llm, llm_model
    if llm is not None and llm_model == model:
        return llm
    # Drop the old model first so it can be garbage-collected before the new
    # one is allocated.
    llm = None
    llm = Llama(
        model_path=f"models/{model}",
        flash_attn=False,
        n_gpu_layers=0,
        n_batch=1024,
        n_ctx=2048 * 4,
        n_threads=2,
        n_threads_batch=2,
        verbose=False,
    )
    llm_model = model
    return llm
def respond(
    message: str,
    history: List[Tuple[str, str]],
    model: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
):
    """Answer *message* with retrieval-augmented generation, streaming the reply.

    The pipeline has three steps:
    1. Flan-T5 rewrites the question into a keyword query (``to_query``).
    2. BM25 retrieves matching documentation chunks (``retriever_tool``).
    3. The selected Qwen GGUF model completes a prompt built by filling
       *system_message* with (retrieved documents, question).

    Args:
        message: The user's question.
        history: Chat history supplied by Gradio. It is not used; each turn is
            answered independently.
        model: GGUF file name of the answer model. ``None`` yields nothing.
        system_message: Prompt template with two ``%s`` placeholders (context,
            then question). Any literal ``%`` must be written as ``%%``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        repeat_penalty: Repetition penalty.

    Yields:
        str: The accumulated response text after each streamed chunk.
    """
    if model is None:
        return
    query = to_query(message)
    document = retriever_tool(query=query)
    # Format first, so a broken template fails before a model is loaded.
    prompt = system_message % (document, message)
    # Pass the selected model through so the dropdown choice is honoured.
    answer(document, message, model)
    response = ""
    for chunk in llm(
        prompt,
        max_tokens=max_tokens,
        stream=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repeat_penalty=repeat_penalty,
    ):
        response += chunk["choices"][0]["text"]
        yield response
# Title and Markdown description shown at the top of the chat UI.
title = "llama.cpp Qwen2.5-0.5B-Rag-Thinking-Flan-T5"
description = """
- I use a forked [llama-cpp-python](https://github.com/fairydreaming/llama-cpp-python/tree/t5) that supports T5 on the server, but it doesn't support newer models (like Gemma 3).
- Search query generation (query reformulation): I use flan-t5-base (flan-t5-large gives better results, but is too large for just this task).
- Qwen2.5-0.5B: good for its small size.
- Anyway, Google's T5 series is amazing on CPU.

## Hugging Face Free CPU Limitations
- When duplicating a space, the build process can occasionally become stuck, requiring a manual restart to finish.
- Spaces may unexpectedly stop functioning or even be deleted, leading to the need to rework them. Refer to [this issue](https://github.com/huggingface/hub-docs/issues/1633) for more information.
"""
# Gradio chat UI. The additional inputs are passed to respond() positionally
# after (message, history): model, system prompt template, max tokens,
# temperature, top-p, top-k, repetition penalty.
demo = gr.ChatInterface(
    respond,
    examples=[["What is the Diffuser?"], ["Tell me About Huggingface."], ["How to upload dataset?"]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Dropdown(
            choices=[
                "Qwen2.5-0.5B-Rag-Thinking.i1-Q6_K.gguf",
            ],
            value="Qwen2.5-0.5B-Rag-Thinking.i1-Q6_K.gguf",
            label="Model",
            info="Select the AI model to use for chat",
            visible=False,
        ),
        # respond() fills this template with `template % (context, question)`,
        # so the info text tells users to keep both placeholders.
        gr.Textbox(
            value=qwen_prompt,
            label="System Prompt",
            info="Prompt template for the answer model: keep the two %s placeholders (retrieved context, then question) and write any literal % as %%",
            lines=2,
            visible=True,
        ),
        gr.Slider(
            minimum=1024,
            maximum=8192,
            value=2048,
            step=1,
            label="Max Tokens",
            info="Maximum length of response (higher = longer replies)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Creativity level (higher = more creative, lower = more focused)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling threshold",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=40,
            step=1,
            label="Top-k",
            info="Limit vocabulary choices to top K tokens",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition Penalty",
            info="Penalize repeated words (higher = less repetition)",
        ),
    ],
    theme="Ocean",
    submit_btn="Send",
    stop_btn="Stop",
    title=title,
    description=description,
    chatbot=gr.Chatbot(scale=1, show_copy_button=True),
    flagging_mode="never",
)

# Launch the chat interface when run as a script.
if __name__ == "__main__":
    demo.launch(debug=False)