import os
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face token, read from the environment (None is fine for public models)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Global model & tokenizer, populated by load_model()
tokenizer = None
model = None

# Load selected model
def load_model(model_name):
    global tokenizer, model
    full_model_name = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    model.eval()

# Initialize default model
load_model("LeCarnet-8M")

# Streaming generation: generate() runs in a background thread while this
# generator yields the partial response as the streamer decodes new tokens
def respond(message, max_tokens, temperature, top_p):
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False echoes the prompt, so the stream reads as one continuous text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
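
# TextIteratorStreamer is fed by the generate() call in the background thread:
# iterating it blocks until new tokens arrive and stops once generation hits
# EOS or max_new_tokens, at which point the worker thread finishes on its own.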

# User input handler: append the new message and clear the textbox
def user(message, chat_history):
    chat_history.append([message, None])
    return "", chat_history

# Bot response handler
def bot(chatbot, max_tokens, temperature, top_p):
    message = chatbot[-1][0]
    response_generator = respond(message, max_tokens, temperature, top_p)
    for response in response_generator:
        chatbot[-1][1] = response
        yield chatbot
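
# Yielding the updated history on every chunk lets Gradio re-render the
# Chatbot incrementally, which produces the streaming effect in the UI.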

# Model selector handler: load the chosen model and reset the conversation
def update_model(model_name):
    load_model(model_name)
    return []  # returning an empty history clears the chat

# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    # The logo image is expected in a 'static' directory next to this script
    image_path = os.path.join("static", "le-carnet.png")
    with gr.Row(equal_height=True):
        # Header image (tag reconstructed; assumes Gradio's "/file=" route for local files)
        gr.Markdown(
            f'<img src="/file={image_path}" alt="LeCarnet">',
            elem_classes="header-image"
        )
        gr.Markdown(
            "