import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
CUSTOM_CSS = """
.gradio-container {
    background: linear-gradient(to right, #FFDEE9, #B5FFFC);
    color: black;
}
"""
DESCRIPTION = """# Bonjour Dans le chat du consentement
Mistral-7B Instruct Demo
"""
MAX_INPUT_TOKEN_LENGTH = 4096 # just a default
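# (The model itself supports a much longer context window; 4096 is simply a conservative cap for this demo.)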
# Define model/tokenizer at the top so they're visible in all scopes
tokenizer = None
model = None
# Try to load the model only if GPU is available
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
else:
    # Show a warning in the description
    DESCRIPTION += "\n**Running on CPU** - this model is too large for CPU inference!"

def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
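    """Stream the model's reply token by token; yields an error message instead when no GPU (and thus no model) is available."""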
    # If there's no GPU (thus no tokenizer/model), return an error
    if tokenizer is None or model is None:
        yield "Error: No GPU available. Unable to load Mistral-7B-Instruct."
        return

    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=20.0,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }
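    # Run generation in a background thread so the streamer below can be
    # iterated here, yielding partial text to the UI as tokens arrive.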
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

demo = gr.ChatInterface(
    fn=generate,
    description=DESCRIPTION,
    css=CUSTOM_CSS,
    examples=None,
    type="messages"
)

if __name__ == "__main__":
    demo.queue().launch()