File size: 3,010 Bytes
c566ded
 
 
 
 
 
967f284
29ac499
c566ded
967f284
c566ded
30169f7
c566ded
 
86ef0b6
 
30169f7
86ef0b6
 
41dc826
c566ded
30169f7
 
c566ded
30169f7
 
c566ded
 
 
30169f7
c566ded
 
5f14f54
c566ded
 
 
 
 
 
30169f7
 
c566ded
 
 
 
 
 
 
 
 
 
 
 
5f14f54
30169f7
5f14f54
c566ded
 
30169f7
c566ded
30169f7
c566ded
 
30169f7
 
c566ded
 
30169f7
c566ded
30169f7
 
 
c566ded
 
30169f7
 
c566ded
30169f7
c566ded
30169f7
c566ded
30169f7
 
 
c566ded
30169f7
 
c566ded
 
 
 
 
 
 
 
 
30169f7
 
 
c566ded
 
 
 
30169f7
 
 
c566ded
 
5f14f54
30169f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# ---------------------------------------------------------------------------
# 1) Page styling: a pastel left-to-right gradient behind the Gradio
#    container, with the container's text colour set to black.
# ---------------------------------------------------------------------------
CUSTOM_CSS = (
    "\n"
    ".gradio-container {\n"
    "    background: linear-gradient(to right, #FFDEE9, #B5FFFC);\n"
    "    color: black; /* ensure text appears in black */\n"
    "}\n"
)

#
# 2) Markdown shown above the chat box (text colour comes from CUSTOM_CSS).
#    When no GPU is found, append a notice: the model below is only loaded
#    on CUDA, so on CPU the chat cannot produce replies.
#
DESCRIPTION = """# Bonjour Dans le chat du consentement  
Mistral-7B Instruct Demo  
"""

if not torch.cuda.is_available():
    # The previous wording implied the demo merely runs slowly on CPU; in
    # fact the model is never loaded without CUDA.
    DESCRIPTION += (
        "Running on CPU - no GPU was found, so the model is not loaded "
        "and chat is unavailable.\n"
    )

# Upper bound on prompt length (in tokens); longer conversations are trimmed
# from the left in generate(). Overridable via the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

#
# 3) Load Mistral-7B Instruct (requires gating, GPU required here).
#
# Weights are only loaded when CUDA is available. On CPU both `model` and
# `tokenizer` stay bound to None, so code referencing them sees an explicit
# sentinel instead of failing with a NameError.
#
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
model = None
tokenizer = None

if torch.cuda.is_available():
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # half precision to fit the 7B model on GPU
        device_map="auto",          # let transformers place weights on available devices
        trust_remote_code=True,
    )

def generate(
    message: str,
    chat_history: list[dict],
) -> Iterator[str]:
    """Stream a model reply to the latest user message.

    Minimal chat generation function: no sliders, no extra params.

    Args:
        message: The new user message.
        chat_history: Prior turns in Gradio "messages" format; each entry is
            expected to have at least "role" and "content" keys.

    Yields:
        The reply accumulated so far. Each yield is the full text generated
        up to that point, not just the newest fragment.

    Raises:
        gr.Error: If CUDA is unavailable, because the module-level loading
            code only creates the model and tokenizer when a GPU is present.
    """
    # Mirror the loading condition so CPU users get a readable UI error
    # instead of a NameError/AttributeError on `tokenizer` / `model`.
    if not torch.cuda.is_available():
        raise gr.Error(
            "No GPU available: the Mistral-7B model is only loaded on CUDA devices."
        )

    # Gradio message dicts may carry extra keys besides role/content; pass
    # the chat template only the fields it needs.
    conversation = [
        {"role": turn["role"], "content": turn["content"]} for turn in chat_history
    ]
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the template's assistant-turn prefix
    # (no effect for templates that define none) so the model starts a reply.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Keep only the most recent tokens. NOTE: this left-trim may cut through
    # the start of the formatted prompt, including template control tokens.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    input_ids = input_ids.to(model.device)

    # The streamer hands decoded text from the generation thread to this
    # generator; timeout bounds how long we wait for each new piece.
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=20.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Basic sampling settings.
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
    )

    # model.generate blocks until done, so run it in a background thread and
    # consume the streamer here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted, so generation has finished; reap the thread.
    t.join()

# ---------------------------------------------------------------------------
# 4) Chat UI: a bare ChatInterface around `generate` in "messages" mode,
#    with no extra input controls and no canned example prompts.
# ---------------------------------------------------------------------------
demo = gr.ChatInterface(
    generate,
    type="messages",
    examples=None,  # no pre-filled example questions
    css=CUSTOM_CSS,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    # Enable the request queue (needed for streaming generators), then serve.
    queued_demo = demo.queue()
    queued_demo.launch()