# Spaces: Sleeping (page-header text captured with the file; not part of the program)
import os | |
import threading | |
import gradio as gr | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
TextIteratorStreamer, | |
) | |
# Selectable checkpoints: display name shown in the UI -> Hugging Face Hub repo id.
MODEL_PATHS = {
    f"LeCarnet-{size}": f"MaxLSB/LeCarnet-{size}" for size in ("3M", "8M", "21M")
}

# Hugging Face access token, read from the environment; the app refuses to
# start without it.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token is None or hf_token == "":
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# Currently loaded tokenizer/model pair. Only one model is kept loaded at a
# time; load_model() fills these in.
tokenizer = model = None
def load_model(model_name):
    """Load the tokenizer and model for ``model_name`` into the module globals.

    Args:
        model_name: A key of ``MODEL_PATHS`` selecting the checkpoint to load.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODEL_PATHS``.
    """
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")

    repo_id = MODEL_PATHS[model_name]
    print(f"Loading {model_name}...")

    # Load into locals first: if either download/initialisation fails, the
    # previously loaded tokenizer/model pair stays intact instead of ending up
    # with a new tokenizer paired with the old model.
    new_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
    new_model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_token)
    new_model.eval()

    # Publish both together only once everything succeeded.
    tokenizer, model = new_tokenizer, new_model
    print(f"{model_name} loaded.")
# Load the first model declared in MODEL_PATHS at import time so the app has
# a ready tokenizer/model pair before the first request.
initial_model = list(MODEL_PATHS)[0]
load_model(initial_model)
def respond(
    prompt: str,
    chat_history,
    model_choice: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a sampled continuation of ``prompt`` from the selected model.

    Args:
        prompt: Text the model should continue.
        chat_history: Previous chat turns supplied by the chat UI; not used.
        model_choice: Key of ``MODEL_PATHS`` naming the model to generate with.
        max_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The text produced so far. The streamer is created with
        ``skip_prompt=False``, so the prompt itself is part of the output.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has ended.
    """
    # Reload model if it's not the currently loaded one
    if model.config._name_or_path != MODEL_PATHS[model_choice]:
        load_model(model_choice)

    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        # UI sliders may hand over a float; generation needs an integer count.
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    run_generation = model.generate
    worker_errors = []

    def _generate():
        # Runs in the background thread. Without this guard, an exception in
        # generate() would leave the consumer loop below blocked forever,
        # because the streamer has no timeout and would never get its
        # end-of-stream signal.
        try:
            run_generation(**generate_kwargs)
        except Exception as exc:  # thread boundary: re-raised by the consumer
            worker_errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    accumulated = ""
    for new_text in streamer:
        accumulated += new_text
        yield accumulated

    # The stream has ended, so the worker is finishing; reap it and surface
    # any failure to the caller instead of silently truncating the output.
    thread.join()
    if worker_errors:
        raise worker_errors[0]
# --- Gradio Interface ---
# Custom stylesheet handed to gr.Blocks(css=...). It removes the default
# container padding, sizes the element with id "chatbot" relative to the
# viewport height, and centers/scales the image inside ".logo-container".
css = """
.gradio-container {
    padding: 0 !important;
}
.gradio-container > main.fillable {
    padding: 0 !important;
}
#chatbot {
    height: calc(100vh - 21px - 16px);
    max-height: 1500px;
}
#chatbot .chatbot-conversations {
    height: 100vh;
    background-color: var(--ms-gr-ant-color-bg-layout);
    padding-left: 4px;
    padding-right: 4px;
}
#chatbot .chatbot-conversations .chatbot-conversations-list {
    padding-left: 0;
    padding-right: 0;
}
#chatbot .chatbot-chat {
    padding: 32px;
    padding-bottom: 0;
    height: 100%;
}
@media (max-width: 768px) {
    #chatbot .chatbot-chat {
        padding: 0;
    }
}
#chatbot .chatbot-chat .chatbot-chat-messages {
    flex: 1;
}
.logo-container {
    display: flex;
    justify-content: center;
    padding: 10px;
}
.logo-container img {
    max-width: 80%; /* Adjust as needed */
    height: auto;
}
"""
# One-click example prompts shown by the chat interface (one input per example).
example_prompts = [
    [text]
    for text in (
        "Il était une fois un petit garçon qui vivait dans un village paisible.",
        "Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang.",
        "Il était une fois un petit lapin perdu",
    )
]

# Page layout: logo and title on top, then a settings column (model picker and
# sampling sliders) next to a wider chat column wired to respond().
with gr.Blocks(fill_width=True, css=css) as demo:
    with gr.Column(variant="panel", elem_id="chatbot"):
        # Custom logo banner
        with gr.Row(elem_classes="logo-container"):
            gr.Image(
                value="media/le-carnet.png",  # Path to the logo image file
                height=100,  # Adjust height as needed
                label="LeCarnet Logo",
                show_label=False,
                show_download_button=False,
                interactive=False,
            )

        gr.Markdown(
            """
            # LeCarnet AI Assistant
            Type the beginning of a sentence and watch the model finish it.
            """
        )

        with gr.Row():
            # Left: generation settings.
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    label="Choose Model",
                    choices=list(MODEL_PATHS),
                    value=initial_model,
                    interactive=True,
                )
                max_tokens_slider = gr.Slider(
                    minimum=1, maximum=512, value=512, step=1, label="Max new tokens"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top‑p"
                )

            # Right: the chat itself. The settings components are forwarded to
            # respond() as extra arguments, in this order.
            with gr.Column(scale=3):
                extra_inputs = [
                    model_dropdown,
                    max_tokens_slider,
                    temperature_slider,
                    top_p_slider,
                ]
                chatbot = gr.ChatInterface(
                    fn=respond,
                    additional_inputs=extra_inputs,
                    examples=example_prompts,
                    cache_examples=False,
                    submit_btn="Generate",
                    clear_btn="Clear Chat",
                )
if __name__ == "__main__":
    # Script entry point: turn on request queuing, then start the web server.
    demo.queue().launch()