#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """
<h1 style="color:black;">Mistral-7B v0.3</h1>
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

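# Token budgets: a hard cap and a default for generation, plus a prompt-length
# cap that can be overridden through the MAX_INPUT_TOKEN_LENGTH environment variable.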
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
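
# Load the model and tokenizer once at startup; device_map="auto" places the
# fp16 weights on the available GPU.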
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

# This Space runs on ZeroGPU, which allocates a GPU per call: functions that
# touch CUDA must be decorated with @spaces.GPU (otherwise the `spaces` import
# is unused and generation has no GPU to run on).
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
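    # Stream decoded tokens as they are generated instead of waiting for the full completion.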
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so this generator can consume the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
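    # Accumulate streamed chunks and yield the growing reply so the chat UI updates incrementally.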
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# CSS applying the pastel gradient to the WHOLE page
custom_css = """
html, body {
    height: 100%;
    margin: 0;
    padding: 0;
    background: linear-gradient(135deg, #FDE2E2, #E2ECFD) !important;
}
"""

# Predefined example questions (kept in French: the demo targets a French-speaking audience)
predefined_examples = [
    ["1 - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],  # noqa: RUF001
    ["2 - C’est quoi une agression sexuelle ?"],
    ["3 - C’est quoi un viol ?"],
    ["4 - C’est quoi un attouchement ?"],
    ["5 - C’est quoi un harcèlement sexuel ?"],
    ["6 - Est-ce illégal de visionner du porno ?"],
    ["7 - Mon copain me demande un nude, dois-je le faire ?"],
    ["8 - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
    [
        "9 - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"
    ],
]
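
# ChatInterface streams whatever `generate` yields; type="messages" keeps the
# history as a list of {"role", "content"} dicts, matching the chat_history parameter.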
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css=custom_css,  # Apply the global pastel CSS
    examples=predefined_examples,
)

if __name__ == "__main__":
    # Bound the request queue so concurrent users wait rather than overloading the GPU worker.
    demo.queue(max_size=20).launch()