Mahavaury2's picture
Update app.py
b723c84 verified
raw
history blame
3.12 kB
#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = " "
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
if torch.cuda.is_available():
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
@spaces.GPU
def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
"""Fonction de génération sans sliders : les paramètres
de génération (max_new_tokens, température, etc.) sont
fixés en dur.
"""
# Valeurs par défaut fixées
max_new_tokens = 1024
temperature = 0.6
top_p = 0.9
top_k = 50
repetition_penalty = 1.2
# Prépare la conversation
conversation = [*chat_history, {"role": "user", "content": message}]
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# On ne fournit plus 'additional_inputs' ici, donc aucun slider ne sera affiché
demo = gr.ChatInterface(
fn=generate,
stop_btn=None,
examples=[
["C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],
["C’est quoi une agression sexuelle ?"],
["C’est quoi un viol ?"],
["C’est quoi un attouchement ?"],
["C’est quoi un harcèlement sexuel ?"],
["Est-ce illégal de visionner du porno ?"],
["C’est quoi un harcèlement sexuel ?"],
["Mon copain me demande un nude, dois-je le faire ?"],
["Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
["Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"
],
],
type="messages",
description=DESCRIPTION,
css_paths="style.css",
)
if __name__ == "__main__":
demo.queue(max_size=20).launch(share=True)