import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

CUSTOM_CSS = """
.gradio-container {
    background: linear-gradient(to right, #FFDEE9, #B5FFFC);
    color: black;
}
"""
DESCRIPTION = """# Bonjour Dans le chat du consentement
Mistral-7B Instruct Demo
"""

MAX_INPUT_TOKEN_LENGTH = 4096  # default cap on the prompt length, in tokens

# Define model/tokenizer at the top so they're visible in all scopes
tokenizer = None
model = None
# Try to load the model only if a GPU is available
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
else:
    # No GPU: leave tokenizer/model as None and surface a warning in the description
    DESCRIPTION += "\n**Running on CPU** — This model is too large for CPU inference!"
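

# generate() streams the reply: model.generate runs in a background thread and
# feeds a TextIteratorStreamer, which the generator below drains chunk by chunk
# so Gradio can update the chat message incrementally.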
def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
    # If there's no GPU (and thus no tokenizer/model), yield an error message instead
    if tokenizer is None or model is None:
        yield "Error: No GPU available. Unable to load Mistral-7B-Instruct."
        return

    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens of the conversation
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
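
    # TextIteratorStreamer exposes the generated tokens as a Python iterator;
    # skip_prompt drops the echoed input and skip_special_tokens strips tokens like </s>.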
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=20.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }
    # Run generation in a background thread so tokens can be yielded as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


demo = gr.ChatInterface(
    fn=generate,
    description=DESCRIPTION,
    css=CUSTOM_CSS,
    examples=None,
    type="messages",
)
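
# type="messages" makes Gradio pass the chat history as a list of
# {"role", "content"} dicts, the format tokenizer.apply_chat_template expects above.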

if __name__ == "__main__":
    demo.queue().launch()