Spaces:
Runtime error
Runtime error
File size: 3,410 Bytes
8824f88 749d210 8824f88 b507d58 749d210 8824f88 b507d58 8824f88 7cb6017 8824f88 749d210 b507d58 8824f88 749d210 b507d58 8824f88 b507d58 749d210 8824f88 0737a9d 34353a1 0737a9d 8824f88 b507d58 8824f88 b507d58 c0a913b b507d58 8824f88 b507d58 749d210 8824f88 749d210 b507d58 8824f88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Description de la démo
DESCRIPTION = "# Mistral-7B v0.3"
# Si pas de GPU détecté, afficher un message
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
if torch.cuda.is_available():
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
@spaces.GPU
def generate(
message: str,
chat_history: list[dict],
# Valeurs par défaut pour la génération
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
"""
Génération de texte à partir de l'historique de conversation (chat_history) et du message utilisateur.
"""
conversation = [*chat_history, {"role": "user", "content": message}]
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(
tokenizer,
timeout=20.0,
skip_prompt=True,
skip_special_tokens=True
)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# CSS pastel (dégradé sur toute la page)
custom_css = """
.gradio-container {
background: linear-gradient(135deg, #FDE2E2, #E2ECFD);
}
"""
# Liste des questions prédéfinies
predefined_examples = [
["1 - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],
["2 - C’est quoi une agression sexuelle ?"],
["3 - C’est quoi un viol ?"],
["4 - C’est quoi un attouchement ?"],
["5 - C’est quoi un harcèlement sexuel ?"],
["6 - Est ce illégal de visionner du porno ?"],
["7 - Mon copain me demande un nude, dois-je le faire ?"],
["8 - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
["9 - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"],
]
# Création de l'interface de chat
demo = gr.ChatInterface(
fn=generate,
type="messages",
description=DESCRIPTION,
css=custom_css,
examples=predefined_examples,
)
if __name__ == "__main__":
demo.queue(max_size=20).launch()
|