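"""Gradio demo: streaming chat with mistralai/Mistral-7B-Instruct-v0.3.

The model is loaded only when a CUDA GPU is available; otherwise the UI
explains that the model is too large for CPU inference.
"""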
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

CUSTOM_CSS = """
.gradio-container {
    background: linear-gradient(to right, #FFDEE9, #B5FFFC);
    color: black;
}
"""

DESCRIPTION = """# Bonjour Dans le chat du consentement  
Mistral-7B Instruct Demo  
"""

MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))  # default 4096, overridable via env

# Define model/tokenizer at the top so they're visible in all scopes
tokenizer = None
model = None

# Try to load the model only if GPU is available
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
else:
    # Show a warning in the description
    DESCRIPTION += "\n**Running on CPU** — This model is too large for CPU inference!"

def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
    # Without a GPU the model was never loaded, so yield an error message instead
    if tokenizer is None or model is None:
        yield "Error: No GPU available. Unable to load Mistral-7B-Instruct."
        return

    conversation = [*chat_history, {"role": "user", "content": message}]

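    # Format the conversation with the model's chat template, then keep only the
    # most recent MAX_INPUT_TOKEN_LENGTH tokens so long histories still fit.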
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    input_ids = input_ids.to(model.device)

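    # TextIteratorStreamer exposes generated text as a plain iterator;
    # skip_prompt drops the input tokens so only newly generated text is yielded.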
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=20.0,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }

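    # Run generation in a background thread so the loop below can stream
    # partial output to the UI as tokens arrive.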
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

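# type="messages" passes the history as a list of {"role", "content"} dicts,
# matching the chat_history parameter expected by generate().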
demo = gr.ChatInterface(
    fn=generate,
    description=DESCRIPTION,
    css=CUSTOM_CSS,
    examples=None,
    type="messages"
)

if __name__ == "__main__":
    demo.queue().launch()