# Importing required libraries
import warnings
warnings.filterwarnings("ignore")
import os
import json
import subprocess
import sys
from llama_cpp import Llama, llama_model_decoder_start_token
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple
from logger import logging
from exception import CustomExceptionHandling
# Download the GGUF model file
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
os.makedirs("models", exist_ok=True)

# mtsdurica/madlad400-3b-mt-Q8_0-GGUF
hf_hub_download(
    repo_id="mtsdurica/madlad400-3b-mt-Q8_0-GGUF",
    filename="madlad400-3b-mt-q8_0.gguf",
    local_dir="./models",
)
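# Note: `huggingface_token` is read above but not forwarded to the download
# call. For a public repo like this one it is not required; for a gated or
# private repo you would typically pass it along (sketch only, not needed here):
#
#   hf_hub_download(
#       repo_id="mtsdurica/madlad400-3b-mt-Q8_0-GGUF",
#       filename="madlad400-3b-mt-q8_0.gguf",
#       local_dir="./models",
#       token=huggingface_token,
#   )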
# Define the prompt markers for Gemma 3
gemma_3_prompt_markers = {
    Roles.system: PromptMarkers("", "\n"),  # System prompt should be included within the user message
    Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
    Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
    Roles.tool: PromptMarkers("", ""),  # If you need tool support
}

# Create the formatter
gemma_3_formatter = MessagesFormatter(
    pre_prompt="",  # No pre-prompt
    prompt_markers=gemma_3_prompt_markers,
    include_sys_prompt_in_first_user_message=True,  # Include system prompt in first user message
    default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
    strip_prompt=False,  # Don't strip whitespace from the prompt
    bos_token="<bos>",  # Beginning-of-sequence token for Gemma 3
    eos_token="<eos>",  # End-of-sequence token for Gemma 3
)
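# For reference, with the markers above a single exchange is rendered roughly
# like this (illustration only; the actual string is built by MessagesFormatter):
#
#   <start_of_turn>user
#   {system message}
#   {user message}<end_of_turn>
#   <start_of_turn>model
#   {assistant reply}<end_of_turn>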
# Set the title and description
title = "Gemma Llama.cpp"
description = """Gemma 3 is a family of lightweight, multimodal open models that offers advanced capabilities like large context windows and multilingual support, enabling diverse applications on various devices."""
llm = None
llm_model = None
# Extra imports for the low-level llama.cpp API demo below
import ctypes
import multiprocessing

import llama_cpp
def test():
    # Minimal low-level llama.cpp walkthrough: encode a prompt with the encoder
    # of the MADLAD-400 (T5-style encoder-decoder) model, then decode greedily
    # from the decoder start token.
    llama_cpp.llama_backend_init()

    N_THREADS = multiprocessing.cpu_count()
    MODEL_PATH = "models/madlad400-3b-mt-q8_0.gguf"
    prompt = b"translate English to German: The house is wonderful."

    lparams = llama_cpp.llama_model_default_params()
    model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode("utf-8"), lparams)
    vocab = llama_cpp.llama_model_get_vocab(model)

    cparams = llama_cpp.llama_context_default_params()
    cparams.no_perf = False
    ctx = llama_cpp.llama_init_from_model(model, cparams)

    sparams = llama_cpp.llama_sampler_chain_default_params()
    smpl = llama_cpp.llama_sampler_chain_init(sparams)
    llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_greedy())

    n_past = 0

    # Tokenize the prompt into a ctypes token array
    embd_inp = (llama_cpp.llama_token * (len(prompt) + 1))()
    n_of_tok = llama_cpp.llama_tokenize(
        vocab,
        prompt,
        len(prompt),
        embd_inp,
        len(embd_inp),
        True,
        True,
    )
    embd_inp = embd_inp[:n_of_tok]

    n_ctx = llama_cpp.llama_n_ctx(ctx)

    n_predict = 20
    n_predict = min(n_predict, n_ctx - len(embd_inp))

    input_consumed = 0
    input_noecho = False
    remaining_tokens = n_predict

    embd = []
    last_n_size = 64
    last_n_tokens_data = [0] * last_n_size
    n_batch = 24
    # Kept from the original low-level example; unused with the greedy sampler
    last_n_repeat = 64
    repeat_penalty = 1
    frequency_penalty = 0.0
    presence_penalty = 0.0

    batch = llama_cpp.llama_batch_init(n_batch, 0, 1)

    # Prepare a batch for encoding containing the prompt
    batch.n_tokens = len(embd_inp)
    for i in range(batch.n_tokens):
        batch.token[i] = embd_inp[i]
        batch.pos[i] = i
        batch.n_seq_id[i] = 1
        batch.seq_id[i][0] = 0
        batch.logits[i] = False

    llama_cpp.llama_encode(ctx, batch)

    # Now overwrite embd_inp so the batch for decoding will initially contain only
    # a single token with the id acquired from llama_model_decoder_start_token(model)
    embd_inp = [llama_cpp.llama_model_decoder_start_token(model)]

    while remaining_tokens > 0:
        if len(embd) > 0:
            batch.n_tokens = len(embd)
            for i in range(batch.n_tokens):
                batch.token[i] = embd[i]
                batch.pos[i] = n_past + i
                batch.n_seq_id[i] = 1
                batch.seq_id[i][0] = 0
                batch.logits[i] = i == batch.n_tokens - 1
            llama_cpp.llama_decode(ctx, batch)

        n_past += len(embd)
        embd = []

        if len(embd_inp) <= input_consumed:
            # Sample the next token greedily
            id = llama_cpp.llama_sampler_sample(smpl, ctx, -1)
            last_n_tokens_data = last_n_tokens_data[1:] + [id]
            embd.append(id)
            input_noecho = False
            remaining_tokens -= 1
        else:
            # Feed any remaining input tokens into the decode batch
            while len(embd_inp) > input_consumed:
                embd.append(embd_inp[input_consumed])
                last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
                input_consumed += 1
                if len(embd) >= n_batch:
                    break

        if not input_noecho:
            for id in embd:
                size = 32
                buffer = (ctypes.c_char * size)()
                n = llama_cpp.llama_token_to_piece(
                    vocab, llama_cpp.llama_token(id), buffer, size, 0, True
                )
                assert n <= size
                print(
                    buffer[:n].decode("utf-8"),
                    end="",
                    flush=True,
                )

        if len(embd) > 0 and embd[-1] in [
            llama_cpp.llama_token_eos(vocab),
            llama_cpp.llama_token_eot(vocab),
        ]:
            break

    print()
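# The low-level demo above never releases its native handles. A minimal cleanup
# sketch is given below, assuming the standard llama.cpp free functions exposed
# by the bindings; it is not wired into test() and is provided for reference only.
def _free_low_level_handles(batch, smpl, ctx, model):
    llama_cpp.llama_batch_free(batch)   # release the token batch
    llama_cpp.llama_sampler_free(smpl)  # release the sampler chain
    llama_cpp.llama_free(ctx)           # release the context
    llama_cpp.llama_free_model(model)   # release the model weights
    llama_cpp.llama_backend_free()      # shut down the backend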
def trans(text):
    """Translate `text` to Japanese with MADLAD-400, yielding partial output.

    MADLAD-400 is an encoder-decoder (T5-style) model: the tagged input is run
    through the encoder, then the decoder is sampled starting from the model's
    decoder start token.
    """
    llama = llm

    # Prefix the text with the target-language tag and convert it to bytes
    input_text = f"<2ja>{text}".encode("utf-8")

    # Tokenize and run the encoder
    tokens = llama.tokenize(input_text)
    print("Tokens:", tokens)
    llama.encode(tokens)

    # Decoding starts from the decoder start token
    initial_tokens = [llama.decoder_start_token()]
    print("Initial Tokens:", initial_tokens)

    # Greedy generation; yield the accumulated translation as it grows
    buf = b""
    for token in llama.generate(
        initial_tokens, top_k=0, top_p=0.95, temp=0.0, repeat_penalty=1.0
    ):
        if token == llama.token_eos():
            break
        buf += llama.detokenize([token])
        # Decode the accumulated bytes; a trailing partial UTF-8 sequence is
        # only dropped from the intermediate preview, not from later yields
        yield buf.decode("utf-8", errors="ignore")
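# Usage sketch (assumes respond() has already loaded the model into the global
# `llm`); iterating the generator streams the growing translation:
#
#   for partial in trans("The house is wonderful."):
#       print(partial)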
def respond(
    message: str,
    history: List[Tuple[str, str]],
    model: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
):
    """
    Respond to a message using the Gemma 3 model via Llama.cpp.

    Args:
        - message (str): The message to respond to.
        - history (List[Tuple[str, str]]): The chat history.
        - model (str): The model to use.
        - system_message (str): The system message to use.
        - max_tokens (int): The maximum number of tokens to generate.
        - temperature (float): The temperature of the model.
        - top_p (float): The top-p of the model.
        - top_k (int): The top-k of the model.
        - repeat_penalty (float): The repetition penalty of the model.

    Returns:
        str: The response to the message.
    """
    try:
        # Load the global variables
        global llm
        global llm_model
        # llama = Llama("madlad400-3b-mt-q8_0.gguf")

        # Load the model
        if llm is None or llm_model != model:
            llm = Llama(
                model_path=f"models/{model}",
                flash_attn=False,
                n_gpu_layers=0,
                n_batch=8,
                n_ctx=2048,
                n_threads=8,
                n_threads_batch=8,
            )
            llm_model = model

        # Stream the translation and stop here; the agent path below is
        # currently bypassed while the MADLAD-400 translation is being tested.
        for partial in trans(message):
            yield partial
        return

        provider = LlamaCppPythonProvider(llm)

        # Create the agent
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            # predefined_messages_formatter_type=GEMMA_2,
            custom_messages_formatter=gemma_3_formatter,
            debug_output=True,
        )

        # Set the settings like temperature, top-k, top-p, max tokens, etc.
        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        messages = BasicChatHistory()

        # Add the chat history
        for msn in history:
            user = {"role": Roles.user, "content": msn[0]}
            assistant = {"role": Roles.assistant, "content": msn[1]}
            messages.add_message(user)
            messages.add_message(assistant)

        # Get the response stream
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        # Log the success
        logging.info("Response stream generated successfully")

        # Generate the response
        outputs = ""
        for output in stream:
            outputs += output
            yield outputs

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e
# Create a chat interface
demo = gr.ChatInterface(
    respond,
    examples=[
        ["What is the capital of France?"],
        ["Tell me something about artificial intelligence."],
        ["What is gravity?"],
    ],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Dropdown(
            choices=[
                "madlad400-3b-mt-q8_0.gguf",
            ],
            value="madlad400-3b-mt-q8_0.gguf",
            label="Model",
            info="Select the AI model to use for chat",
        ),
        gr.Textbox(
            value="You are a helpful assistant.",
            label="System Prompt",
            info="Define the AI assistant's personality and behavior",
            lines=2,
        ),
        gr.Slider(
            minimum=512,
            maximum=2048,
            value=1024,
            step=1,
            label="Max Tokens",
            info="Maximum length of response (higher = longer replies)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Creativity level (higher = more creative, lower = more focused)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling threshold",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=40,
            step=1,
            label="Top-k",
            info="Limit vocabulary choices to top K tokens",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition Penalty",
            info="Penalize repeated words (higher = less repetition)",
        ),
    ],
    theme="Ocean",
    submit_btn="Send",
    stop_btn="Stop",
    title=title,
    description=description,
    chatbot=gr.Chatbot(scale=1, show_copy_button=True),
    flagging_mode="never",
)
# Launch the chat interface
if __name__ == "__main__":
    demo.launch(debug=False)
    # Low-level API smoke test (runs only after the Gradio app exits)
    test()