Spaces:
Runtime error
Runtime error
File size: 3,222 Bytes
8824f88 749d210 8824f88 2a7984e a7f5e8b 8824f88 a7f5e8b 7cb6017 8824f88 749d210 b507d58 8824f88 749d210 8824f88 749d210 0737a9d 34353a1 0737a9d 8824f88 eed53c2 8824f88 a7f5e8b b507d58 94c9a27 b507d58 a7f5e8b eed53c2 a7f5e8b 8824f88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """
<h1 style="color:black;">Mistral-7B v0.3</h1>
"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
if torch.cuda.is_available():
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
@spaces.GPU
def generate(
message: str,
chat_history: list[dict],
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
conversation = [*chat_history, {"role": "user", "content": message}]
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# CSS pour appliquer le dégradé pastel à TOUTE la page
custom_css = """
html, body {
height: 100%;
margin: 0;
padding: 0;
background: linear-gradient(135deg, #FDE2E2, #E2ECFD) !important;
}
"""
# Questions prédéfinies
predefined_examples = [
["1 - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],
["2 - C’est quoi une agression sexuelle ?"],
["3 - C’est quoi un viol ?"],
["4 - C’est quoi un attouchement ?"],
["5 - C’est quoi un harcèlement sexuel ?"],
["6 - Est ce illégal de visionner du porno ?"],
["7 - Mon copain me demande un nude, dois-je le faire ?"],
["8 - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
["9 - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"],
]
demo = gr.ChatInterface(
fn=generate,
type="messages",
description=DESCRIPTION,
css=custom_css, # On applique le CSS pastel global
examples=predefined_examples
)
if __name__ == "__main__":
demo.queue(max_size=20).launch()
|