|
|
|
import warnings |
|
warnings.filterwarnings("ignore") |
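
# Gradio demo: serves the quantized flan-t5-large-grammar-synthesis GGUF model with
# llama-cpp-python (fairydreaming's T5 branch) and streams the corrected text to a chat UI.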
|
|
|
import os |
|
import json |
|
import subprocess |
|
import sys |
|
from llama_cpp import Llama, llama_model_decoder_start_token
|
from llama_cpp_agent import LlamaCppAgent |
|
from llama_cpp_agent import MessagesFormatterType |
|
from llama_cpp_agent.providers import LlamaCppPythonProvider |
|
from llama_cpp_agent.chat_history import BasicChatHistory |
|
from llama_cpp_agent.chat_history.messages import Roles |
|
|
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
from typing import List, Tuple |
|
from logger import logging |
|
from exception import CustomExceptionHandling |
|
|
|
|
|
|
|
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") |
|
os.makedirs("models", exist_ok=True)
|
|
|
|
|
|
|
# Download the quantized GGUF checkpoint from the Hugging Face Hub into ./models.
hf_hub_download(
    repo_id="pszemraj/flan-t5-large-grammar-synthesis",
    filename="ggml-model-Q6_K.gguf",
    local_dir="./models",
    token=huggingface_token,  # optional; only needed for gated or private repos
)
|
|
|
|
|
|
|
|
|
title = "flan-t5-large-grammar-synthesis Llama.cpp" |
|
description = """ |
|
I'm using [fairydreaming/T5-branch](https://github.com/fairydreaming/llama-cpp-python/tree/t5); I'm not sure whether the current llama-cpp-python server supports T5 models.
|
|
|
[Model-Q6_K-GGUF](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis-gguf), [Reference1](https://huggingface.co/spaces/sitammeur/Gemma-llamacpp) |
|
""" |
|
|
|
|
|
llama = None  # Global Llama instance, loaded lazily on the first request.
|
|
|
|
|
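# Note: generation in respond() relies on encoder-decoder support (llama.encode and
# llama.decoder_start_token), which comes from the fairydreaming T5 branch of
# llama-cpp-python linked in the description above.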
|
|
|
def respond( |
|
message: str, |
|
history: List[Tuple[str, str]], |
|
model: str, |
|
system_message: str, |
|
max_tokens: int, |
|
temperature: float, |
|
top_p: float, |
|
top_k: int, |
|
repeat_penalty: float, |
|
): |
|
""" |
|
    Respond to a message using the flan-t5-large-grammar-synthesis model via Llama.cpp.
|
|
|
Args: |
|
- message (str): The message to respond to. |
|
- history (List[Tuple[str, str]]): The chat history. |
|
- model (str): The model to use. |
|
- system_message (str): The system message to use. |
|
- max_tokens (int): The maximum number of tokens to generate. |
|
- temperature (float): The temperature of the model. |
|
- top_p (float): The top-p of the model. |
|
- top_k (int): The top-k of the model. |
|
- repeat_penalty (float): The repetition penalty of the model. |
|
|
|
    Yields:
        str: The accumulated response text, streamed as it is generated.
|
""" |
|
    if model is None:
|
return |
|
try: |
|
global llama |
|
        # Load the quantized model once on the first request and reuse it afterwards.
        if llama is None:
            model_id = "ggml-model-Q6_K.gguf"
            llama = Llama(
                f"models/{model_id}",
                flash_attn=False,
                n_gpu_layers=0,  # CPU-only inference
                n_ctx=max_tokens,
                n_threads=2,
                n_threads_batch=2,
                verbose=False,
            )
|
|
|
        # Encoder-decoder generation: encode the input, then sample the decoder
        # token by token starting from the decoder start token.
        tokens = llama.tokenize(message.encode("utf-8"))
        llama.encode(tokens)
        tokens = [llama.decoder_start_token()]
        outputs = ""
        for token in llama.generate(
            tokens,
            top_k=top_k,
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty,
        ):
            # Stream the partial output back to the UI and stop at end-of-sequence.
            outputs += llama.detokenize([token]).decode()
            yield outputs
            if token == llama.token_eos():
                break

        return outputs
|
    except Exception as e:
        # Surface any failure through the project's custom exception handler.
        raise CustomExceptionHandling(e, sys) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo = gr.ChatInterface( |
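    # Streaming chat UI: `respond` is a generator, so partial output appears as it is produced.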
|
respond, |
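    # The example prompts are deliberately ungrammatical; the grammar-correction model is expected to rewrite them.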
|
examples=[["What are the capital of France?"], ["What real child was raise by wolves?"], ["What am gravity?"]], |
|
additional_inputs_accordion=gr.Accordion( |
|
label="⚙️ Parameters", open=False, render=False |
|
), |
|
additional_inputs=[ |
|
gr.Dropdown( |
|
choices=[ |
|
"ggml-model-Q6_K.gguf", |
|
], |
|
value="ggml-model-Q6_K.gguf", |
|
label="Model", |
|
info="Select the AI model to use for chat", |
|
visible=False |
|
), |
|
gr.Textbox( |
|
value="You are a helpful assistant.", |
|
label="System Prompt", |
|
info="Define the AI assistant's personality and behavior", |
|
            lines=2,
            visible=False,
|
), |
|
gr.Slider( |
|
minimum=512, |
|
maximum=512, |
|
value=512, |
|
step=1, |
|
label="Max Tokens", |
|
            info="Maximum length of response (higher = longer replies)",
            visible=False,
|
), |
|
gr.Slider( |
|
minimum=0.1, |
|
maximum=2.0, |
|
value=0.4, |
|
step=0.1, |
|
label="Temperature", |
|
info="Creativity level (higher = more creative, lower = more focused)", |
|
), |
|
gr.Slider( |
|
minimum=0.1, |
|
maximum=1.0, |
|
value=0.95, |
|
step=0.05, |
|
label="Top-p", |
|
info="Nucleus sampling threshold", |
|
), |
|
gr.Slider( |
|
minimum=1, |
|
maximum=100, |
|
value=40, |
|
step=1, |
|
label="Top-k", |
|
info="Limit vocabulary choices to top K tokens", |
|
), |
|
gr.Slider( |
|
minimum=1.0, |
|
maximum=2.0, |
|
value=1.1, |
|
step=0.1, |
|
label="Repetition Penalty", |
|
info="Penalize repeated words (higher = less repetition)", |
|
), |
|
], |
|
theme="Ocean", |
|
submit_btn="Send", |
|
stop_btn="Stop", |
|
title=title, |
|
description=description, |
|
chatbot=gr.Chatbot(scale=1, show_copy_button=True), |
|
flagging_mode="never", |
|
) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
demo.launch(debug=False) |
|
|
|