|
|
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
import os |
|
import json |
|
import subprocess |
|
import sys |
|
from llama_cpp import Llama
|
from llama_cpp_agent import LlamaCppAgent |
|
from llama_cpp_agent import MessagesFormatterType |
|
from llama_cpp_agent.providers import LlamaCppPythonProvider |
|
from llama_cpp_agent.chat_history import BasicChatHistory |
|
from llama_cpp_agent.chat_history.messages import Roles
|
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
from typing import List, Tuple |
|
from logger import logging |
|
from exception import CustomExceptionHandling |
|
|
|
|
|
|
|
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") |
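
# Create the local model directory and download the quantized MADLAD-400 3B translation model (GGUF Q8_0).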
|
os.makedirs("models",exist_ok=True) |
|
|
|
hf_hub_download(
    repo_id="mtsdurica/madlad400-3b-mt-Q8_0-GGUF",
    filename="madlad400-3b-mt-q8_0.gguf",
    local_dir="./models",
    token=huggingface_token,
)
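
# Gemma 3 chat template: user and model turns are wrapped in <start_of_turn>/<end_of_turn>
# markers, and the system prompt is folded into the first user message.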
|
|
|
|
|
|
|
gemma_3_prompt_markers = { |
|
Roles.system: PromptMarkers("", "\n"), |
|
Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"), |
|
Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"), |
|
Roles.tool: PromptMarkers("", ""), |
|
} |
|
|
|
|
|
gemma_3_formatter = MessagesFormatter( |
|
pre_prompt="", |
|
prompt_markers=gemma_3_prompt_markers, |
|
include_sys_prompt_in_first_user_message=True, |
|
default_stop_sequences=["<end_of_turn>", "<start_of_turn>"], |
|
strip_prompt=False, |
|
bos_token="<bos>", |
|
eos_token="<eos>", |
|
) |
|
|
|
|
|
|
|
title = "Gemma Llama.cpp" |
|
description = """Gemma 3 is a family of lightweight, multimodal open models that offers advanced capabilities like large context windows and multilingual support, enabling diverse applications on various devices.""" |
|
|
|
|
|
llm = None |
|
llm_model = None |
|
|
|
import ctypes |
|
|
import multiprocessing |
|
|
|
import llama_cpp |
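
# Low-level llama.cpp smoke test: tokenize a tagged prompt, run the MADLAD-400 encoder,
# then greedily decode the translation token by token.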
|
|
|
def test(): |
|
|
|
|
|
    llama_cpp.llama_backend_init()
|
|
|
N_THREADS = multiprocessing.cpu_count() |
|
MODEL_PATH = "models/madlad400-3b-mt-q8_0.gguf" |
|
|
|
prompt = b"translate English to German: The house is wonderful." |
|
|
|
lparams = llama_cpp.llama_model_default_params() |
|
model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode("utf-8"), lparams) |
|
|
|
vocab = llama_cpp.llama_model_get_vocab(model) |
|
|
|
    cparams = llama_cpp.llama_context_default_params()
    cparams.no_perf = False
    # Run inference on all available CPU cores.
    cparams.n_threads = N_THREADS
    cparams.n_threads_batch = N_THREADS
    ctx = llama_cpp.llama_init_from_model(model, cparams)
|
|
|
sparams = llama_cpp.llama_sampler_chain_default_params() |
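    # Greedy sampler chain: always pick the highest-probability token.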
|
smpl = llama_cpp.llama_sampler_chain_init(sparams) |
|
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_greedy()) |
|
|
|
n_past = 0 |
|
|
|
embd_inp = (llama_cpp.llama_token * (len(prompt) + 1))() |
|
|
|
n_of_tok = llama_cpp.llama_tokenize( |
|
vocab, |
|
prompt, |
|
len(prompt), |
|
embd_inp, |
|
len(embd_inp), |
|
True, |
|
True, |
|
) |
|
|
|
embd_inp = embd_inp[:n_of_tok] |
|
|
|
n_ctx = llama_cpp.llama_n_ctx(ctx) |
|
|
|
n_predict = 20 |
|
n_predict = min(n_predict, n_ctx - len(embd_inp)) |
|
|
|
input_consumed = 0 |
|
input_noecho = False |
|
|
|
remaining_tokens = n_predict |
|
|
|
embd = [] |
|
last_n_size = 64 |
|
last_n_tokens_data = [0] * last_n_size |
|
n_batch = 24 |
|
|
|
|
    # Size the batch so the whole encoder prompt fits in a single llama_encode() call.
    batch = llama_cpp.llama_batch_init(max(n_batch, len(embd_inp)), 0, 1)
|
|
|
|
|
batch.n_tokens = len(embd_inp) |
|
for i in range(batch.n_tokens): |
|
batch.token[i] = embd_inp[i] |
|
batch.pos[i] = i |
|
batch.n_seq_id[i] = 1 |
|
batch.seq_id[i][0] = 0 |
|
batch.logits[i] = False |
|
|
|
llama_cpp.llama_encode( |
|
ctx, |
|
batch |
|
) |
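
    # Decoding starts from the model's decoder start token (T5-style encoder-decoder).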
|
|
|
|
|
|
|
embd_inp = [llama_cpp.llama_model_decoder_start_token(model)] |
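
    # Autoregressive decode loop: feed pending tokens, sample the next token greedily,
    # and stop at the end-of-sequence / end-of-turn token.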
|
|
|
while remaining_tokens > 0: |
|
if len(embd) > 0: |
|
|
|
batch.n_tokens = len(embd) |
|
for i in range(batch.n_tokens): |
|
batch.token[i] = embd[i] |
|
batch.pos[i] = n_past + i |
|
batch.n_seq_id[i] = 1 |
|
batch.seq_id[i][0] = 0 |
|
batch.logits[i] = i == batch.n_tokens - 1 |
|
|
|
llama_cpp.llama_decode( |
|
ctx, |
|
batch |
|
) |
|
|
|
n_past += len(embd) |
|
embd = [] |
|
if len(embd_inp) <= input_consumed: |
|
id = llama_cpp.llama_sampler_sample(smpl, ctx, -1) |
|
|
|
last_n_tokens_data = last_n_tokens_data[1:] + [id] |
|
embd.append(id) |
|
input_noecho = False |
|
remaining_tokens -= 1 |
|
else: |
|
while len(embd_inp) > input_consumed: |
|
embd.append(embd_inp[input_consumed]) |
|
last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]] |
|
input_consumed += 1 |
|
if len(embd) >= n_batch: |
|
break |
|
if not input_noecho: |
|
for id in embd: |
|
size = 32 |
|
buffer = (ctypes.c_char * size)() |
|
n = llama_cpp.llama_token_to_piece( |
|
vocab, llama_cpp.llama_token(id), buffer, size, 0, True |
|
) |
|
assert n <= size |
|
print( |
|
buffer[:n].decode("utf-8"), |
|
end="", |
|
flush=True, |
|
) |
|
|
|
if len(embd) > 0 and embd[-1] in [llama_cpp.llama_token_eos(vocab), llama_cpp.llama_token_eot(vocab)]: |
|
break |
|
|
|
print() |
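
    # Free llama.cpp resources allocated by this smoke test (assumes the current
    # low-level API names: llama_batch_free, llama_sampler_free, llama_free,
    # llama_model_free, llama_backend_free).
    llama_cpp.llama_batch_free(batch)
    llama_cpp.llama_sampler_free(smpl)
    llama_cpp.llama_free(ctx)
    llama_cpp.llama_model_free(model)
    llama_cpp.llama_backend_free()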
|
|
|
|
|
def trans(text):
    """Translate `text` to Japanese with the MADLAD-400 model loaded in `llm`, yielding partial output."""
    # MADLAD-400 is a T5-style encoder-decoder model: the source text is prefixed with a
    # target-language tag such as <2ja>, run through the encoder, and the translation is
    # generated by the decoder starting from the decoder start token.
    input_text = f"<2ja>{text}".encode("utf-8")

    tokens = llm.tokenize(input_text)
    print("Tokens:", tokens)

    # Run the encoder over the source tokens, then decode greedily from the decoder start token.
    llm.encode(tokens)
    tokens = [llm.decoder_start_token()]

    buf = ""
    for token in llm.generate(tokens, top_k=0, top_p=0.95, temp=0.0, repeat_penalty=1.0):
        if token == llm.token_eos():
            break
        buf += llm.detokenize([token]).decode("utf-8", errors="ignore")
        yield buf
|
|
|
def respond( |
|
message: str, |
|
history: List[Tuple[str, str]], |
|
model: str, |
|
system_message: str, |
|
max_tokens: int, |
|
temperature: float, |
|
top_p: float, |
|
top_k: int, |
|
repeat_penalty: float, |
|
): |
|
""" |
|
    Respond to a chat message. Currently streams a MADLAD-400 translation of the message via Llama.cpp; the Gemma 3 agent path below is kept but bypassed.
|
|
|
Args: |
|
- message (str): The message to respond to. |
|
- history (List[Tuple[str, str]]): The chat history. |
|
- model (str): The model to use. |
|
- system_message (str): The system message to use. |
|
- max_tokens (int): The maximum number of tokens to generate. |
|
- temperature (float): The temperature of the model. |
|
- top_p (float): The top-p of the model. |
|
- top_k (int): The top-k of the model. |
|
- repeat_penalty (float): The repetition penalty of the model. |
|
|
|
    Yields:
        str: The accumulated response text as it is generated.
    """
|
try: |
|
|
|
global llm |
|
global llm_model |
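
        # (Re)load the GGUF model only when the selected model changes between requests.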
|
|
|
|
|
|
|
if llm is None or llm_model != model: |
|
llm = Llama( |
|
model_path=f"models/{model}", |
|
flash_attn=False, |
|
n_gpu_layers=0, |
|
n_batch=8, |
|
n_ctx=2048, |
|
n_threads=8, |
|
n_threads_batch=8, |
|
) |
|
llm_model = model |
|
|
|
        # Short-circuit: stream the MADLAD-400 translation instead of running the Gemma agent below.
        yield from trans(message)
        return
|
|
|
provider = LlamaCppPythonProvider(llm) |
|
|
|
|
|
agent = LlamaCppAgent( |
|
provider, |
|
system_prompt=f"{system_message}", |
|
|
|
custom_messages_formatter=gemma_3_formatter, |
|
debug_output=True, |
|
) |
|
|
|
|
|
settings = provider.get_provider_default_settings() |
|
settings.temperature = temperature |
|
settings.top_k = top_k |
|
settings.top_p = top_p |
|
settings.max_tokens = max_tokens |
|
settings.repeat_penalty = repeat_penalty |
|
settings.stream = True |
|
|
|
messages = BasicChatHistory() |
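
        # Replay previous chat turns into the agent's history.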
|
|
|
|
|
for msn in history: |
|
user = {"role": Roles.user, "content": msn[0]} |
|
assistant = {"role": Roles.assistant, "content": msn[1]} |
|
messages.add_message(user) |
|
messages.add_message(assistant) |
|
|
|
|
|
stream = agent.get_chat_response( |
|
message, |
|
llm_sampling_settings=settings, |
|
chat_history=messages, |
|
returns_streaming_generator=True, |
|
print_output=False, |
|
) |
|
|
|
|
|
logging.info("Response stream generated successfully") |
|
|
|
|
|
outputs = "" |
|
for output in stream: |
|
outputs += output |
|
yield outputs |
|
|
|
|
|
except Exception as e: |
|
|
|
raise CustomExceptionHandling(e, sys) from e |
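
# Gradio chat UI: wires respond() to a model selector and sampling controls.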
|
|
|
|
|
|
|
demo = gr.ChatInterface( |
|
respond, |
|
examples=[["What is the capital of France?"], ["Tell me something about artificial intelligence."], ["What is gravity?"]], |
|
additional_inputs_accordion=gr.Accordion( |
|
label="⚙️ Parameters", open=False, render=False |
|
), |
|
additional_inputs=[ |
|
gr.Dropdown( |
|
choices=[ |
|
"madlad400-3b-mt-q8_0.gguf", |
|
], |
|
value="madlad400-3b-mt-q8_0.gguf", |
|
label="Model", |
|
info="Select the AI model to use for chat", |
|
), |
|
gr.Textbox( |
|
value="You are a helpful assistant.", |
|
label="System Prompt", |
|
info="Define the AI assistant's personality and behavior", |
|
lines=2, |
|
), |
|
gr.Slider( |
|
minimum=512, |
|
maximum=2048, |
|
value=1024, |
|
step=1, |
|
label="Max Tokens", |
|
info="Maximum length of response (higher = longer replies)", |
|
), |
|
gr.Slider( |
|
minimum=0.1, |
|
maximum=2.0, |
|
value=0.7, |
|
step=0.1, |
|
label="Temperature", |
|
info="Creativity level (higher = more creative, lower = more focused)", |
|
), |
|
gr.Slider( |
|
minimum=0.1, |
|
maximum=1.0, |
|
value=0.95, |
|
step=0.05, |
|
label="Top-p", |
|
info="Nucleus sampling threshold", |
|
), |
|
gr.Slider( |
|
minimum=1, |
|
maximum=100, |
|
value=40, |
|
step=1, |
|
label="Top-k", |
|
info="Limit vocabulary choices to top K tokens", |
|
), |
|
gr.Slider( |
|
minimum=1.0, |
|
maximum=2.0, |
|
value=1.1, |
|
step=0.1, |
|
label="Repetition Penalty", |
|
info="Penalize repeated words (higher = less repetition)", |
|
), |
|
], |
|
theme="Ocean", |
|
submit_btn="Send", |
|
stop_btn="Stop", |
|
title=title, |
|
description=description, |
|
chatbot=gr.Chatbot(scale=1, show_copy_button=True), |
|
flagging_mode="never", |
|
) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
demo.launch(debug=False) |
|
test() |
|
|