File size: 3,222 Bytes
8824f88
 
 
749d210
8824f88
 
 
 
 
 
 
2a7984e
 
 
a7f5e8b
8824f88
 
 
a7f5e8b
 
7cb6017
8824f88
 
749d210
b507d58
 
 
 
 
8824f88
 
 
 
 
749d210
8824f88
 
 
 
 
 
749d210
0737a9d
 
 
34353a1
0737a9d
8824f88
eed53c2
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7f5e8b
b507d58
94c9a27
 
 
 
 
b507d58
 
 
a7f5e8b
 
 
 
 
 
 
 
 
 
 
 
eed53c2
a7f5e8b
 
 
 
 
 
 
8824f88
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """
<h1 style="color:black;">Mistral-7B v0.3</h1>
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# CSS pour appliquer le dégradé pastel à TOUTE la page
custom_css = """
html, body {
    height: 100%;
    margin: 0;
    padding: 0;
    background: linear-gradient(135deg, #FDE2E2, #E2ECFD) !important;
}
"""

# Questions prédéfinies
predefined_examples = [
    ["1 - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],
    ["2 - C’est quoi une agression sexuelle ?"],
    ["3 - C’est quoi un viol ?"],
    ["4 - C’est quoi un attouchement ?"],
    ["5 - C’est quoi un harcèlement sexuel ?"],
    ["6 - Est ce illégal de visionner du porno ?"],
    ["7 - Mon copain me demande un nude, dois-je le faire ?"],
    ["8 - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
    ["9 - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"],
]

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css=custom_css,       # On applique le CSS pastel global
    examples=predefined_examples
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()