File size: 3,410 Bytes
8824f88
 
 
749d210
8824f88
 
 
 
 
 
 
b507d58
749d210
8824f88
b507d58
8824f88
 
 
 
 
7cb6017
8824f88
 
749d210
b507d58
 
 
 
 
8824f88
 
 
 
 
749d210
b507d58
8824f88
 
 
 
 
 
b507d58
 
 
749d210
8824f88
0737a9d
 
 
34353a1
0737a9d
8824f88
b507d58
 
 
 
 
 
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b507d58
 
 
 
 
 
 
 
 
c0a913b
 
 
 
 
 
 
 
 
b507d58
8824f88
b507d58
749d210
8824f88
749d210
 
b507d58
 
8824f88
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Description de la démo
DESCRIPTION = "# Mistral-7B v0.3"

# Si pas de GPU détecté, afficher un message
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    # Valeurs par défaut pour la génération
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """
    Génération de texte à partir de l'historique de conversation (chat_history) et du message utilisateur.
    """
    conversation = [*chat_history, {"role": "user", "content": message}]

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=20.0,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# CSS pastel (dégradé sur toute la page)
custom_css = """
.gradio-container {
    background: linear-gradient(135deg, #FDE2E2, #E2ECFD);
}
"""

# Liste des questions prédéfinies
predefined_examples = [
    ["1 - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],
    ["2 - C’est quoi une agression sexuelle ?"],
    ["3 - C’est quoi un viol ?"],
    ["4 - C’est quoi un attouchement ?"],
    ["5 - C’est quoi un harcèlement sexuel ?"],
    ["6 - Est ce illégal de visionner du porno ?"],
    ["7 - Mon copain me demande un nude, dois-je le faire ?"],
    ["8 - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
    ["9 - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"],
]

# Création de l'interface de chat
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css=custom_css,
    examples=predefined_examples,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()